Partitions of Leaves in a Rooted Tree using Dynamic Programming

By | November 14, 2024

Given a rooted tree with n nodes labeled 1 to n and rooted at node 1, the parent of node i is p[i] for i = 2, ..., n (node 1 is the root and has no parent). Let f(L) denote the smallest connected subgraph that contains every leaf in a set of leaves L. The task is to count the number of ways to partition the leaves such that, for any two different sets x and y of the partition, f(x) and f(y) are disjoint. The implementation below reports the count modulo 998244353.

Examples:

Input : n=5
parent_array=[1,1,1,1]
Output : 12
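Node 1 is the root and nodes 2 to 5 are all leaves. If a set contains two or more leaves, its f value passes through node 1, so at most one set of the partition may hold more than one leaf. That gives 1 (all singletons) + 6 (one pair) + 4 (one triple) + 1 (all four together) = 12 partitions. The remaining 3 partitions of the four leaves, which split them into two pairs, are not counted because both subgraphs contain node 1.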

Input : n=10
parent_array=[1,2,3,4,5,6,7,8,9]
Output : 1
The tree is a single path from node 1 to node 10. The only leaf is node 10, so there is exactly one partition.
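To make the definition of f concrete, a brute-force counter can enumerate every partition of the leaves, build f for each set, and test the sets pairwise for disjointness. This is a minimal checking aid for tiny trees only (the number of partitions grows as the Bell number of the leaf count), and the helper names `path_to_root`, `smallest_subgraph`, `set_partitions`, and `brute_force_count` are illustrative, not part of the original article. It assumes the same input format as the examples: p lists the parents of nodes 2 to n.

```python
from itertools import combinations


def path_to_root(v, parent):
    """Nodes on the path from v up to the root, starting with v itself."""
    path = []
    while v != 0:  # 0 marks "no parent" (above the root)
        path.append(v)
        v = parent[v]
    return path


def smallest_subgraph(leaf_set, parent):
    """f(leaf_set): the union of the paths from each leaf up to their lowest common ancestor."""
    paths = [path_to_root(v, parent) for v in leaf_set]
    common = set(paths[0]).intersection(*paths[1:])  # ancestors shared by every leaf
    nodes = set()
    for path in paths:
        for v in path:
            nodes.add(v)
            if v in common:
                break  # the first common ancestor met on the way up is the LCA
    return nodes


def set_partitions(items):
    """Yield every partition of the list `items` as a list of blocks."""
    if not items:
        yield []
        return
    first, rest = items[0], items[1:]
    for part in set_partitions(rest):
        # Put `first` into each existing block in turn ...
        for i in range(len(part)):
            yield part[:i] + [[first] + part[i]] + part[i + 1:]
        # ... or into a block of its own.
        yield [[first]] + part


def brute_force_count(n, p):
    """Count leaf partitions whose f values are pairwise disjoint (tiny trees only)."""
    parent = [0, 0] + list(p)  # parent[i] for i = 2..n
    has_child = [False] * (n + 1)
    for i in range(2, n + 1):
        has_child[parent[i]] = True
    leaves = [v for v in range(1, n + 1) if not has_child[v]]
    count = 0
    for partition in set_partitions(leaves):
        subgraphs = [smallest_subgraph(block, parent) for block in partition]
        if all(a.isdisjoint(b) for a, b in combinations(subgraphs, 2)):
            count += 1
    return count


print(brute_force_count(5, [1, 1, 1, 1]))        # 12
print(brute_force_count(10, list(range(1, 10))))  # 1
```

Both calls reproduce the outputs of the examples above, which makes this handy for cross-checking the DP on small random trees.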

Approach:

We can solve this problem with a tree DP. Here, we extend the partition so that every node inside f(x) is assigned to the set x; because the f values must be disjoint, each node ends up in at most one set. The dp state for node i counts the partitions of the leaves in the subtree rooted at i, subject to a condition on node i itself. We keep 3 values for each node:
dp[i][0] counts the ways in which node i does not belong to any set.
dp[i][1] counts the ways in which node i belongs to some set but is directly connected to only one child in that set, so the set must also continue above i, through its parent.
dp[i][2] counts the ways in which node i belongs to some set and does not need to connect above (node i is a leaf, or it is joined to at least two of its children in that set).

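Note that the three cases are disjoint. As a base case, a leaf has dp[leaf][2] = 1 and its other two values are 0, while an internal node starts from dp[i][0] = 1 (and 0 for the other two) before any child is processed. To compute the dp values of a node, we iterate through its direct children one at a time. If the child joins the node's set, there are dp[child][1] + dp[child][2] ways to choose this, and the node moves from state 0 to state 1, or from state 1 or 2 to state 2. Otherwise, if the child is not in the node's set, there are dp[child][0] + dp[child][2] ways to choose this (a child in state 1 cannot be left out, because its set must continue upward), and the node's state does not change. The final answer is dp[1][0] + dp[1][2], since the root is node 1 and has no parent to connect to.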
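Written out as equations, the merge step for a node v and one of its children u reads as follows. The shorthand J_u (child joins v's set) and N_u (child does not join) is introduced here for readability and is not notation from the original article; dp' denotes the values of v after u has been folded in.

```latex
\begin{aligned}
J_u &= dp[u][1] + dp[u][2], \qquad N_u = dp[u][0] + dp[u][2],\\
dp'[v][0] &= dp[v][0]\,N_u,\\
dp'[v][1] &= dp[v][0]\,J_u + dp[v][1]\,N_u,\\
dp'[v][2] &= \bigl(dp[v][1] + dp[v][2]\bigr)\,J_u + dp[v][2]\,N_u.
\end{aligned}
```

All arithmetic is taken modulo 998244353, and once every child of the root has been processed the answer is dp[1][0] + dp[1][2].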

Why must a node that is directly connected to only one child in its set connect above? Suppose a node is in some set, is directly connected to only one child below in that set, and all leaves of that set lie in its subtree. Then we could remove that node from the f value and obtain a smaller connected subgraph that still contains all those leaves, contradicting the minimality of f. Thus, if a node is directly connected to only one child below, some leaf of the same set must lie outside its subtree, so the set must connect above. A leaf, by contrast, is itself a member of its set and never needs to connect above, which is why dp[leaf][2] = 1.
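Putting the base cases and the merge rule together, the per-child update can be sketched as a small standalone function and traced by hand on the first example. The names `merge`, `LEAF`, and `INTERNAL_START` are hypothetical, chosen for this illustration; the arithmetic matches the update used in the implementation.

```python
MOD = 998244353

LEAF = (0, 0, 1)            # a leaf is in its own set and need not connect above
INTERNAL_START = (1, 0, 0)  # an internal node before any child is processed


def merge(state, child):
    """Fold one child's (dp0, dp1, dp2) triple into its parent's running triple."""
    p0, p1, p2 = state
    c0, c1, c2 = child
    join = c1 + c2   # ways in which the child belongs to the parent's set
    apart = c0 + c2  # ways in which it does not (a state-1 child cannot stay apart)
    return (
        p0 * apart % MOD,                       # parent still in no set
        (p0 * join + p1 * apart) % MOD,         # parent joined to exactly one child
        ((p1 + p2) * join + p2 * apart) % MOD,  # parent joined to two or more children
    )


# First example: root 1 with the four leaves 2, 3, 4, 5 as children.
state = INTERNAL_START
for _ in range(4):
    state = merge(state, LEAF)

answer = (state[0] + state[2]) % MOD
print(state, answer)  # (1, 4, 11) 12
```

After each leaf the root's triple goes (1, 1, 0), (1, 2, 1), (1, 3, 4), (1, 4, 11), and 1 + 11 = 12 matches the expected output.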

Implementation in Python

Below is the implementation of the above approach:

# Python3 implementation of the approach

# Counts are reported modulo this prime
mod = 998244353

# Function to count the partitions
def Count_Partition(dp, stack, col, children, pop, push):

    # Iterative post-order traversal: each node is popped twice
    while stack:
        # Pop the current node
        x = pop()

        # First visit: mark the node, push it back, then push its children
        # so that every child is finished before x is popped again
        if col[x] is None:
            push(x)
            col[x] = 1

            # Push all child nodes of the current node
            for y in children[x]:
                push(y)

        # Second visit: every child of x already has its dp value
        else:
            # Base case: an internal node starts in state 0 (no set yet),
            # a leaf starts in state 2 (its own set, no need to go up)
            if children[x]:
                dp[x] = (1, 0, 0)
            else:
                dp[x] = (0, 0, 1)

            # Fold in the children one at a time:
            # dp[y][1] + dp[y][2] ways if y joins x's set,
            # dp[y][0] + dp[y][2] ways if it does not
            for y in children[x]:
                dp[x] = (
                    dp[x][0] * (dp[y][0] + dp[y][2]) % mod,
                    (dp[x][0] * (dp[y][1] + dp[y][2]) + dp[x][1] * (dp[y][0] + dp[y][2])) % mod,
                    ((dp[x][1] + dp[x][2]) * (dp[y][1] + dp[y][2]) + dp[x][2] * (dp[y][0] + dp[y][2])) % mod
                )

    # Final answer: the root is in no set, or in a set that needs nothing above it
    return (dp[1][0] + dp[1][2]) % mod

# Function to initialize the tree and partitions
def Partition(n, p):
    # Pad so that p[i] is the parent of node i for i = 2..n
    p = [-1, -1] + p

    # Children lists, indexed by the parent
    children = [[] for _ in range(n + 1)]
    for i in range(2, n + 1):
        children[p[i]].append(i)

    # Stack for the iterative traversal
    stack = []

    # Push operation
    push = stack.append

    # Pop operation
    pop = stack.pop

    # Push the root node onto the stack
    push(1)

    # Visited markers (None = not yet expanded)
    col = [None] * (n + 1)

    # dp[i] holds the triple (dp[i][0], dp[i][1], dp[i][2])
    dp = [None for _ in range(n + 1)]
    return Count_Partition(dp, stack, col, children, pop, push)

# Driver Code
if __name__ == "__main__":
    # Total number of nodes in the tree
    n = 5
    # Parent array: p[k] is the parent of node k + 2
    p = [1, 1, 1, 1]
    print(Partition(n, p))

Output:

12
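The explicit stack with visited markers is one way to obtain a post-order. An alternative sketch, not from the original article, processes the nodes in reverse BFS order: BFS from the root lists every parent before its children, so walking that list backwards finishes each child before its parent. The name `count_partitions_bfs` is hypothetical; the transitions are the same as in `Count_Partition`.

```python
from collections import deque

MOD = 998244353


def count_partitions_bfs(n, p):
    """Leaf-partition count using reverse BFS order instead of a two-visit stack."""
    # children[v] lists the children of v; p[k] is the parent of node k + 2
    children = [[] for _ in range(n + 1)]
    for node, parent in enumerate(p, start=2):
        children[parent].append(node)

    # BFS from the root lists every parent before its children.
    order = []
    queue = deque([1])
    while queue:
        v = queue.popleft()
        order.append(v)
        queue.extend(children[v])

    dp = [None] * (n + 1)
    # Walking the BFS order backwards finishes every child before its parent.
    for v in reversed(order):
        d0, d1, d2 = (1, 0, 0) if children[v] else (0, 0, 1)
        for u in children[v]:
            c0, c1, c2 = dp[u]
            join, apart = c1 + c2, c0 + c2
            d0, d1, d2 = (
                d0 * apart % MOD,
                (d0 * join + d1 * apart) % MOD,
                ((d1 + d2) * join + d2 * apart) % MOD,
            )
        dp[v] = (d0, d1, d2)

    return (dp[1][0] + dp[1][2]) % MOD


print(count_partitions_bfs(5, [1, 1, 1, 1]))        # 12
print(count_partitions_bfs(10, list(range(1, 10))))  # 1
```

This avoids the visited array and the second push of each node, at the cost of storing the full BFS order; both versions run in O(n) time.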
Author: Mithlesh Upadhyay

Mithlesh Upadhyay is a Computer Science and AI expert from Madhya Pradesh with a strong academic background (BE in CSE and M.Tech in AI) and over six years of experience in technical content development. He has contributed tech articles, led teams, and worked in Full Stack Development and Data Science. He founded the w3colleges.org portal for learning resources.