Why union find is such a cool algorithm
One of my favorite algorithms in graph theory is Union Find, a beautiful technique for tracking disjoint sets. In this post, a disjoint set is a group of nodes in an undirected graph that are all connected to each other through some path but disconnected from every node outside the set; in graph terms, a connected component.
Suppose we have the following graph. The disjoint sets here are {0, 1, 2, 3} and {4, 5, 6} — clearly disconnected from each other.
flowchart LR
n0((0)) --- n1((1))
n1 --- n2((2))
n0 --- n3((3))
n4((4)) --- n5((5))
n5 --- n6((6))
Let’s say we are asked to check if two nodes u and v are within the same disjoint set, for example u = 1 and v = 3. How can we check this efficiently?
A naive approach
One way is to run BFS or DFS starting from u and check whether it reaches v. This works fine if we only need to check one pair. But what if we need to check a large number of pairs?
Let’s say there are V nodes and E edges, where V and E are both 5 * 10^6 (5 million). We are also given k queries, where k = 5 * 10^6 (5 million). A single BFS or DFS costs O(V + E), so answering every query this way costs O(k * (V + E)), on the order of 5 * 10^13 operations. At this scale, we need a far more efficient way to answer each query.
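For reference, the per-query approach might look like the sketch below. It assumes the graph is stored as an adjacency list `adj` (built here from the example's edges), and `connected_bfs` is a hypothetical helper name. Each call may explore a whole component, which is where the O(V + E) per-query cost comes from.

```python
from collections import deque


def connected_bfs(adj, u, v):
    """Return True if v is reachable from u, using a plain BFS.

    adj is an adjacency list: adj[x] holds the neighbors of x.
    A single call may visit the whole component, so it costs O(V + E).
    """
    if u == v:
        return True
    seen = {u}
    queue = deque([u])
    while queue:
        x = queue.popleft()
        for y in adj[x]:
            if y == v:
                return True
            if y not in seen:
                seen.add(y)
                queue.append(y)
    return False


# The example graph: {0, 1, 2, 3} and {4, 5, 6}.
edges = [(0, 1), (1, 2), (0, 3), (4, 5), (5, 6)]
adj = [[] for _ in range(7)]
for a, b in edges:
    adj[a].append(b)
    adj[b].append(a)

print(connected_bfs(adj, 1, 3))  # True
print(connected_bfs(adj, 1, 5))  # False
```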
Let’s stick with the BFS/DFS approach for a moment. Instead of one traversal per query, we can traverse the entire graph once, keeping track of visited nodes, and return the following node -> parent mapping.
Here, parent is a single node chosen to stand for its whole disjoint set: every node in the set maps to it. We also call this node the representative of the disjoint set.
flowchart BT
n0((0))
n4((4))
n1((1)) --> n0
n2((2)) --> n0
n3((3)) --> n0
n5((5)) --> n4
n6((6)) --> n4
We can construct this mapping in O(V + E) time because the traversal processes each node and each edge a constant number of times. Then each query just checks whether the two nodes have the same parent, which takes O(1). The total time complexity is O(V + E + k), which is doable.
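A minimal sketch of that one-pass labeling, assuming nodes are numbered 0 to n - 1 and the graph arrives as an edge list; `label_components` is an illustrative name, not from any library. Each BFS starts from the smallest unvisited node, which becomes the representative, so on the example graph it reproduces the mapping above.

```python
from collections import deque


def label_components(n, edges):
    """Map every node to a representative of its connected component.

    One BFS per unvisited node; each node and edge is handled a constant
    number of times, so the whole pass is O(V + E).
    """
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    parent = [-1] * n  # -1 means "not visited yet"
    for start in range(n):
        if parent[start] != -1:
            continue
        # start becomes the representative of its component.
        parent[start] = start
        queue = deque([start])
        while queue:
            x = queue.popleft()
            for y in adj[x]:
                if parent[y] == -1:
                    parent[y] = start
                    queue.append(y)
    return parent


edges = [(0, 1), (1, 2), (0, 3), (4, 5), (5, 6)]
parent = label_components(7, edges)
print(parent)                  # [0, 0, 0, 0, 4, 4, 4]
print(parent[1] == parent[3])  # True, and each such query is O(1)
```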
However, a problem arises when we want to keep connecting disjoint sets dynamically. Let’s say we want to connect nodes 2 and 5. The graph then looks like this.
flowchart LR
n0((0)) --- n1((1))
n1 --- n2((2))
n0 --- n3((3))
n4((4)) --- n5((5))
n5 --- n6((6))
n2 --- n5
We first look at the node -> parent mapping and check whether the two nodes have the same parent. If they don’t, we update the parents so that all nodes in the two merged sets share the same parent. I picked 0 as the new parent here. The mapping then becomes:
flowchart BT
n0((0))
n1((1)) --> n0
n2((2)) --> n0
n3((3)) --> n0
n4((4)) --> n0
n5((5)) --> n0
n6((6)) --> n0
This update can take O(V) because, in the worst case, we have to touch almost every node. Now, if we add the requirement that a new connection gets added on every query, the current approach becomes infeasible: the time complexity becomes O(V + E + V * k) = O(V * k). As a rough calculation, assuming a computer that handles 10^8 operations per second, we’d need (5 * 10^6) * (5 * 10^6) / 10^8 = 2.5 * 10^5 seconds ≈ 69.4 hours just to compute this.
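The relabeling merge can be sketched as follows, assuming the node -> parent mapping is a plain Python list like the one a labeling pass produces; `merge_naive` is a hypothetical name. The full scan over `parent` is the O(V) cost per update, and the last lines simply redo the back-of-the-envelope estimate.

```python
def merge_naive(parent, u, v):
    """Merge the sets of u and v by relabeling every node in v's set.

    The loop scans all V entries of parent, so each merge costs O(V).
    """
    old, new = parent[v], parent[u]
    if old == new:
        return  # already in the same set
    for x in range(len(parent)):
        if parent[x] == old:
            parent[x] = new


parent = [0, 0, 0, 0, 4, 4, 4]  # mapping for the example graph
merge_naive(parent, 2, 5)       # add the edge 2 - 5; 0 stays the parent
print(parent)                   # [0, 0, 0, 0, 0, 0, 0]

# Rough cost estimate: V * k operations at 10^8 operations per second.
V = k = 5 * 10**6
print(V * k / 10**8 / 3600)     # about 69.4 hours
```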
Enter Union Find
Can we do better? It turns out there’s an algorithm that handles each update and each query in nearly constant time: more specifically, in O(α(n)) amortized time, where n is the number of nodes and α(n) is the inverse Ackermann function. For all practical inputs, α(n) can be treated as a constant. So, counting the O(V) setup and the E initial edges, the whole problem takes O(V + (E + k) * α(V)) time, which is effectively linear.
Union Find is also useful when you’re only given a list of edges. To run DFS or BFS you first have to build an adjacency list or adjacency matrix; with Union Find that step is unnecessary, since union takes the two endpoints of an edge directly as its arguments.
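To illustrate that point, here is a stripped-down sketch that counts connected components by feeding an edge list straight into union operations, with no adjacency list. It is a simpler variant than the full implementation below: it skips rank, and its find uses path halving (pointing each visited node at its grandparent) instead of recursive path compression. `count_components` is an illustrative name.

```python
def count_components(n, edges):
    """Count connected components directly from an edge list."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    components = n
    for u, v in edges:  # each edge is consumed as-is
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv  # merge the two sets
            components -= 1
    return components


edges = [(0, 1), (1, 2), (0, 3), (4, 5), (5, 6)]
print(count_components(7, edges))             # 2
print(count_components(7, edges + [(2, 5)]))  # 1
```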
Code
Below is the complete code for the union find algorithm.
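class UnionFind:
    def __init__(self, size):
        # Every node starts out as the root of its own one-node set.
        self.parent = list(range(size))
        # rank[r] is an upper bound on the height of the tree rooted at r.
        self.rank = [0] * size

    def find(self, u):
        # Walk up to the root, pointing every node on the path straight at it
        # (path compression).
        if self.parent[u] == u:
            return u
        self.parent[u] = self.find(self.parent[u])
        return self.parent[u]

    def union(self, u, v):
        u = self.find(u)
        v = self.find(v)
        if u == v:
            return  # already in the same set
        # Union by rank: hang the shorter tree under the taller tree's root.
        if self.rank[u] < self.rank[v]:
            u, v = v, u
        self.parent[v] = u
        # Merging two trees of equal rank makes the result one level taller.
        if self.rank[u] == self.rank[v]:
            self.rank[u] += 1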
The implementation rests on two optimizations: union by rank and path compression.
Union by Rank
Let’s use the following example. The parent of the first set is 0, and the parent of the second set is 3.
flowchart TB
n0((0)) --- n1((1))
n1 --- n2((2))
n3((3)) --- n4((4))
Let’s see what happens if we don’t consider rank.
parent stores a forest: each node points upward to its parent, and the root of each tree is the representative of its disjoint set. A union simply points one root at the other. If we pick that direction blindly, we can hang a tall tree under a short one, and the combined tree ends up taller than it needs to be.
Without a rank check, union(2, 3) might make 3 the new root, giving the tree below, which is one level taller than before. Repeat this pattern and the tree degenerates into a chain, so find becomes O(n) in the worst case.
flowchart TB
n3((3)) --- n0((0))
n3 --- n4((4))
n0 --- n1((1))
n1 --- n2((2))
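The degenerate case is easy to reproduce. This sketch (helper names are hypothetical) implements union with no rank check, always hanging the first root under the second, and links nodes 0-1, 1-2, and so on in order. The result is a single chain, so a find starting from node 0 has to climb n - 1 links.

```python
def find_root(parent, x):
    # Plain find with no path compression, so the raw tree shape is kept.
    while parent[x] != x:
        x = parent[x]
    return x


def depth(parent, x):
    # Number of parent links between x and its root.
    steps = 0
    while parent[x] != x:
        x = parent[x]
        steps += 1
    return steps


def union_no_rank(parent, u, v):
    ru, rv = find_root(parent, u), find_root(parent, v)
    if ru != rv:
        parent[ru] = rv  # always hang u's tree under v's root, ignoring heights


n = 1000
parent = list(range(n))
for i in range(n - 1):
    union_no_rank(parent, i, i + 1)  # the growing tree goes under fresh node i + 1

print(find_root(parent, 0))  # 999
print(depth(parent, 0))      # 999: the tree is the chain 0 -> 1 -> ... -> 999
```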
Now let’s introduce the concept of rank. The rank loosely captures the height of the tree. When connecting one tree to another, we make the shorter tree’s root point to the taller tree’s root. This keeps the resulting tree shorter, like the following.
flowchart TB
n0((0)) --- n3((3))
n0 --- n1((1))
n3 --- n4((4))
n1 --- n2((2))
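A sketch of the rank rule on its own, with path compression deliberately left out so the tree heights stay visible. The helper names are illustrative, and keeping the first root on a tie is an assumed tie-breaking choice. Linking nodes in order, the pattern that degenerates into a chain when rank is ignored, now yields a tree of depth 1; even pairwise merging of equal-sized trees, the worst case for rank, only reaches depth log2(n).

```python
def find_root(parent, x):
    # Plain find without path compression, to isolate the effect of rank.
    while parent[x] != x:
        x = parent[x]
    return x


def union_by_rank(parent, rank, u, v):
    ru, rv = find_root(parent, u), find_root(parent, v)
    if ru == rv:
        return
    if rank[ru] < rank[rv]:
        ru, rv = rv, ru  # make ru the root of the taller tree
    parent[rv] = ru      # the shorter tree goes under the taller one
    if rank[ru] == rank[rv]:
        rank[ru] += 1    # height only grows when the ranks tie


def max_depth(parent):
    def depth(x):
        steps = 0
        while parent[x] != x:
            x = parent[x]
            steps += 1
        return steps
    return max(depth(x) for x in range(len(parent)))


n = 1024

# Link 0-1, 1-2, ... in order.
parent, rank = list(range(n)), [0] * n
for i in range(n - 1):
    union_by_rank(parent, rank, i, i + 1)
in_order_depth = max_depth(parent)
print(in_order_depth)  # 1

# Merge equal-sized trees pairwise, the worst case for union by rank.
parent, rank = list(range(n)), [0] * n
size = 1
while size < n:
    for i in range(0, n, 2 * size):
        union_by_rank(parent, rank, i, i + size)
    size *= 2
print(max_depth(parent))  # 10, i.e. log2(1024)
```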
Path Compression
Path compression is implemented in the find function. Instead of only walking up to the root, find recursively re-points every node along the path directly at the root, so the next lookup for any of those nodes takes a single step.
Once find(2) and find(4) have run with path compression, the earlier tree becomes the following, with every node sitting directly under the root 0.
flowchart TB
n0((0)) --- n3((3))
n0 --- n1((1))
n0 --- n4((4))
n0 --- n2((2))
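For comparison, path compression can also be written iteratively, which sidesteps Python's recursion limit if trees ever get deep. This two-pass `find` is an alternative sketch rather than the post's code: the first loop locates the root, and the second re-points each node on the path at it. Running it on the union-by-rank tree from above reproduces the flattened tree.

```python
def find(parent, u):
    """Iterative two-pass path compression."""
    # Pass 1: walk up to the root.
    root = u
    while parent[root] != root:
        root = parent[root]
    # Pass 2: point every node on the path directly at the root.
    while u != root:
        nxt = parent[u]
        parent[u] = root
        u = nxt
    return root


# The union-by-rank tree from above: 3 and 1 under 0, 4 under 3, 2 under 1.
parent = [0, 0, 1, 0, 3]
find(parent, 2)
find(parent, 4)
print(parent)  # [0, 0, 0, 0, 0]: every node now points straight at 0
```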
Wrap-up
Union by rank and path compression are both fairly simple on their own — just a few lines of code each — but combined they push the per-operation cost down to effectively constant time. That’s the part that always feels a little magical to me. The inverse Ackermann function shows up as the upper bound from a tight proof, and for any input you can realistically build, α(n) is at most 4 or 5.
Beyond the toy connectivity problem in this post, Union Find shows up in Kruskal’s MST algorithm, percolation models, and connected-component labeling for images. Anywhere you’re tracking “which group does this element belong to” while merges keep happening, it’s likely the right tool.