Minimum Flips in Binary Tree to Get Result

class Solution:
    def minimumFlips(self, root: Optional[TreeNode], result: bool) -> int:
        """Return the minimum number of leaf flips so the tree evaluates to ``result``.

        Leaves hold 0 (False) or 1 (True). Internal nodes hold an operator:
        2 = OR, 3 = AND, 4 = XOR, 5 = NOT. A NOT node has a single child,
        on either side. Flipping a leaf toggles its boolean value.

        The tree is evaluated bottom-up with an explicit stack rather than
        recursion. Deep, chain-shaped trees (e.g. long NOT chains) therefore
        cannot exceed Python's recursion limit.

        Args:
            root: Root of the boolean expression tree.
            result: Desired boolean value of the root.

        Returns:
            The minimum number of leaf flips required.

        Raises:
            ValueError: If ``root`` is None, a node holds an unknown value,
                or an operator node is missing a required child.
        """
        if root is None:
            raise ValueError("root must not be None")

        # Pre-order listing: every node is appended before its operands.
        # Walking the list backwards therefore visits children before parents.
        order = []
        stack = [root]
        while stack:
            node = stack.pop()
            order.append(node)
            stack.extend(self._operands(node))

        # cost[id(node)] = (flips to make node True, flips to make node False).
        # Keys stay valid because `order` keeps every node alive.
        cost = {}
        for node in reversed(order):
            op = node.val
            if op == 0:
                cost[id(node)] = (1, 0)
                continue
            if op == 1:
                cost[id(node)] = (0, 1)
                continue
            if op == 5:  # NOT: making the child False makes this node True
                (child,) = self._operands(node)
                to_true, to_false = cost[id(child)]
                cost[id(node)] = (to_false, to_true)
                continue

            left_true, left_false = cost[id(node.left)]
            right_true, right_false = cost[id(node.right)]
            if op == 2:  # OR: True needs at least one True child
                cost[id(node)] = (
                    min(left_true + right_true,
                        left_false + right_true,
                        left_true + right_false),
                    left_false + right_false,
                )
            elif op == 3:  # AND: False needs at least one False child
                cost[id(node)] = (
                    left_true + right_true,
                    min(left_true + right_false,
                        left_false + right_true,
                        left_false + right_false),
                )
            else:  # op == 4, XOR: True iff the children differ
                cost[id(node)] = (
                    min(left_false + right_true, left_true + right_false),
                    min(left_true + right_true, left_false + right_false),
                )

        to_true, to_false = cost[id(root)]
        return to_true if result else to_false

    @staticmethod
    def _operands(node):
        """Return the child nodes whose values feed into ``node``'s value."""
        op = node.val
        if op in (0, 1):  # leaves have no operands
            return ()
        if op == 5:  # NOT keeps its single child on either side
            child = node.left if node.left is not None else node.right
            if child is None:
                raise ValueError("NOT node has no child")
            return (child,)
        if op in (2, 3, 4):
            if node.left is None or node.right is None:
                raise ValueError(f"operator node {op} needs two children")
            return (node.left, node.right)
        raise ValueError(f"unknown node value: {op!r}")