Advanced Graph Algorithms: Topological Sorting, Shortest Path, Minimum Spanning Tree, and Disjoint Set Union with Code Examples
Topological Sorting
Kahn's Algorithm
import java.util.*;
/**
 * Topological sorting of a small course-prerequisite graph using Kahn's algorithm.
 *
 * <p>An edge {@code X -> Y} is stored in {@code X.edges}. Vertices whose in-degree
 * is zero are emitted first; emitting a vertex lowers the in-degree of each of its
 * successors. If fewer labels are emitted than there are vertices, the remaining
 * vertices never reached in-degree zero and a cycle is reported.
 */
public class KahnTopoSort {

    /** Builds the sample graph and prints either a valid ordering or a cycle notice. */
    public static void main(String[] args) {
        Vertex cs101 = new Vertex("CS101");
        Vertex cs102 = new Vertex("CS102");
        Vertex cs201 = new Vertex("CS201");
        Vertex cs301 = new Vertex("CS301");
        Vertex cs401 = new Vertex("CS401");
        Vertex cs202 = new Vertex("CS202");
        Vertex cs501 = new Vertex("CS501");

        cs101.edges = List.of(new Edge(cs201));
        cs102.edges = List.of(new Edge(cs201));
        cs201.edges = List.of(new Edge(cs301));
        cs202.edges = List.of(new Edge(cs301));
        cs301.edges = List.of(new Edge(cs401));
        cs401.edges = List.of(new Edge(cs501));
        cs501.edges = List.of();

        List<Vertex> courseGraph = List.of(cs101, cs102, cs201, cs301, cs401, cs202, cs501);

        List<String> schedule = kahnOrder(courseGraph);
        if (schedule.size() == courseGraph.size()) {
            for (String course : schedule) {
                System.out.println(course);
            }
        } else {
            System.out.println("Cycle detected in graph");
        }
    }

    /**
     * Runs Kahn's algorithm over the given vertices.
     *
     * @param vertices every vertex of the graph, in the order used to seed the queue
     * @return the labels in the order they were emitted; shorter than {@code vertices}
     *         when some vertex never reached in-degree zero
     */
    private static List<String> kahnOrder(List<Vertex> vertices) {
        // Number of not-yet-emitted predecessors per vertex.
        Map<Vertex, Integer> pending = new HashMap<>();
        vertices.forEach(v -> pending.put(v, 0));
        for (Vertex v : vertices) {
            for (Edge out : v.edges) {
                pending.put(out.target, pending.getOrDefault(out.target, 0) + 1);
            }
        }

        // FIFO queue of vertices that have no remaining predecessors.
        Deque<Vertex> ready = new ArrayDeque<>();
        for (Vertex v : vertices) {
            if (pending.get(v) == 0) {
                ready.addLast(v);
            }
        }

        List<String> emitted = new ArrayList<>();
        while (!ready.isEmpty()) {
            Vertex next = ready.pollFirst();
            emitted.add(next.label);
            for (Edge out : next.edges) {
                int left = pending.getOrDefault(out.target, 0) - 1;
                pending.put(out.target, left);
                if (left == 0) {
                    ready.addLast(out.target);
                }
            }
        }
        return emitted;
    }

    /** Graph node identified by a label, holding its outgoing edges. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();

        Vertex(String name) {
            label = name;
        }
    }

    /** Directed, unweighted edge pointing at {@link #target}. */
    static class Edge {
        Vertex target;

        Edge(Vertex destination) {
            target = destination;
        }
    }
}
Depth-First Search (DFS) Approach
import java.util.*;
/**
 * Topological sorting of a small course-prerequisite graph using depth-first search.
 *
 * <p>Each vertex is coloured while the search runs: {@link #UNVISITED} before it is
 * reached, {@link #IN_PROGRESS} while its successors are being explored and
 * {@link #DONE} afterwards. Reaching an {@code IN_PROGRESS} vertex again means the
 * search followed a back edge, i.e. the graph has a cycle. Cycles are reported through
 * the return value rather than by throwing an exception.
 */
public class DFSTopoSort {

    private static final int UNVISITED = 0;
    private static final int IN_PROGRESS = 1;
    private static final int DONE = 2;

    /** Builds the sample graph and prints either a valid ordering or a cycle notice. */
    public static void main(String[] args) {
        Vertex cs101 = new Vertex("CS101");
        Vertex cs102 = new Vertex("CS102");
        Vertex cs201 = new Vertex("CS201");
        Vertex cs301 = new Vertex("CS301");
        Vertex cs401 = new Vertex("CS401");
        Vertex cs202 = new Vertex("CS202");
        Vertex cs501 = new Vertex("CS501");

        cs101.edges = List.of(new Edge(cs201));
        cs102.edges = List.of(new Edge(cs201));
        cs201.edges = List.of(new Edge(cs301));
        cs202.edges = List.of(new Edge(cs301));
        cs301.edges = List.of(new Edge(cs401));
        cs401.edges = List.of(new Edge(cs501));
        cs501.edges = List.of();

        List<Vertex> courseGraph = List.of(cs101, cs102, cs201, cs301, cs401, cs202, cs501);

        Optional<List<String>> order = topologicalOrder(courseGraph);
        if (order.isPresent()) {
            order.get().forEach(System.out::println);
        } else {
            System.out.println("Cycle detected in graph");
        }
    }

    /**
     * Computes a topological order of the given vertices.
     *
     * <p>The search writes to {@link Vertex#state} and never resets it, so a set of
     * vertices can only be sorted once.
     *
     * @param graph every vertex of the graph; roots are tried in list order
     * @return the labels in topological order, or an empty {@code Optional} if a cycle
     *         was found
     */
    static Optional<List<String>> topologicalOrder(List<Vertex> graph) {
        // Labels are pushed on finish, so reading the deque head-to-tail yields
        // reverse post-order, which is a topological order.
        Deque<String> reversePostOrder = new ArrayDeque<>();
        for (Vertex node : graph) {
            if (node.state == UNVISITED && !visit(node, reversePostOrder)) {
                return Optional.empty();
            }
        }
        return Optional.of(new ArrayList<>(reversePostOrder));
    }

    /**
     * Explores everything reachable from {@code current}.
     *
     * @param current vertex to explore
     * @param stack   receives the label of each vertex once all its successors are done
     * @return {@code false} as soon as a cycle is detected, {@code true} otherwise
     */
    private static boolean visit(Vertex current, Deque<String> stack) {
        if (current.state == DONE) {
            return true;
        }
        if (current.state == IN_PROGRESS) {
            return false; // back edge: current is still on the recursion path
        }
        current.state = IN_PROGRESS;
        for (Edge e : current.edges) {
            if (!visit(e.target, stack)) {
                return false;
            }
        }
        current.state = DONE;
        stack.push(current.label);
        return true;
    }

    /** Graph node identified by a label, holding its outgoing edges and DFS colour. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();
        int state = 0; // 0: unvisited, 1: visiting, 2: visited

        Vertex(String l) {
            label = l;
        }
    }

    /** Directed, unweighted edge pointing at {@link #target}. */
    static class Edge {
        Vertex target;

        Edge(Vertex t) {
            target = t;
        }
    }
}
Shortest Path Algorithms
Dijkstra's Algorithm
- Initialize distance for source to 0, all others to infinity. Mark all nodes as unvisited.
- Select unvisited node with smallest tentative distance as current.
- For each unvisited neighbor of current, compute a new tentative distance: `current.dist + edge.weight`. Update the neighbor's distance and predecessor if the new value is smaller.
- Mark current as visited. Repeat until all nodes are visited or the target is reached.
Basic Implementation
import java.util.*;
/**
 * Dijkstra's single-source shortest paths using a linear scan to pick the next vertex.
 *
 * <p>Results are written into the vertices themselves: {@link Vertex#dist} holds the
 * best known distance from the source ({@code Integer.MAX_VALUE} when no path was
 * found) and {@link Vertex#prev} the predecessor on that path.
 */
public class BasicDijkstra {

    /** Builds the sample graph, runs the algorithm from A and prints every vertex. */
    public static void main(String[] args) {
        Vertex a = new Vertex("A");
        Vertex b = new Vertex("B");
        Vertex c = new Vertex("C");
        Vertex d = new Vertex("D");
        Vertex e = new Vertex("E");
        Vertex f = new Vertex("F");
        a.edges = List.of(new Edge(c, 9), new Edge(b, 7), new Edge(f, 14));
        b.edges = List.of(new Edge(d, 15));
        c.edges = List.of(new Edge(d, 11), new Edge(f, 2));
        d.edges = List.of(new Edge(e, 6));
        e.edges = List.of();
        f.edges = List.of(new Edge(e, 9));
        List<Vertex> graph = List.of(a, b, c, d, e, f);
        computeShortestPaths(graph, a);
        graph.forEach(v -> System.out.printf("%s: %d (prev: %s)%n", v.label, v.dist, v.prev == null ? "none" : v.prev.label));
    }

    /**
     * Computes shortest distances from {@code source} to every vertex in {@code graph}.
     *
     * <p>Package-private so the algorithm can be exercised on graphs other than the
     * sample one.
     *
     * @param graph  all vertices of the graph
     * @param source start vertex; its distance is set to 0
     */
    static void computeShortestPaths(List<Vertex> graph, Vertex source) {
        // Insertion-ordered set: O(1) membership checks while keeping the list order
        // that decides ties between equally distant vertices.
        Set<Vertex> unvisited = new LinkedHashSet<>(graph);
        source.dist = 0;
        while (!unvisited.isEmpty()) {
            Vertex current = getMinDistNode(unvisited);
            if (current.dist == Integer.MAX_VALUE) {
                // Everything still unvisited is unreachable. Relaxing from here would
                // compute MAX_VALUE + weight, which wraps to a negative "distance".
                break;
            }
            relaxNeighbors(current, unvisited);
            unvisited.remove(current);
        }
    }

    /** Returns the first vertex with the smallest tentative distance. */
    private static Vertex getMinDistNode(Collection<Vertex> unvisited) {
        Vertex minNode = null;
        for (Vertex candidate : unvisited) {
            if (minNode == null || candidate.dist < minNode.dist) {
                minNode = candidate;
            }
        }
        return minNode;
    }

    /** Tries to shorten the path to each unvisited neighbour by going through {@code current}. */
    private static void relaxNeighbors(Vertex current, Set<Vertex> unvisited) {
        for (Edge e : current.edges) {
            Vertex neighbor = e.target;
            if (unvisited.contains(neighbor)) {
                // Widen to long so a large distance plus a large weight cannot wrap.
                long newDist = (long) current.dist + e.weight;
                if (newDist < neighbor.dist) {
                    neighbor.dist = (int) newDist;
                    neighbor.prev = current;
                }
            }
        }
    }

    /** Graph node carrying its outgoing edges and the shortest-path result. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();
        int dist = Integer.MAX_VALUE;
        Vertex prev = null;

        Vertex(String l) {
            label = l;
        }

        @Override
        public String toString() {
            return label;
        }
    }

    /** Directed edge to {@link #target} with the given weight. */
    static class Edge {
        Vertex target;
        int weight;

        Edge(Vertex t, int w) {
            target = t;
            weight = w;
        }
    }
}
Priority Queue Optimization
import java.util.*;
/**
 * Dijkstra's single-source shortest paths driven by a {@link PriorityQueue}.
 *
 * <p>The queue holds immutable (vertex, distance) snapshots instead of the vertices
 * themselves. A vertex's distance changes while older snapshots of it are still
 * queued; ordering the heap by a field that is mutated after insertion would break
 * the heap's ordering. Outdated snapshots are discarded when polled ("lazy deletion").
 */
public class PQDijkstra {

    /** Builds the sample graph, runs the algorithm from A and prints every vertex. */
    public static void main(String[] args) {
        Vertex a = new Vertex("A");
        Vertex b = new Vertex("B");
        Vertex c = new Vertex("C");
        Vertex d = new Vertex("D");
        Vertex e = new Vertex("E");
        Vertex f = new Vertex("F");
        a.edges = List.of(new Edge(c, 9), new Edge(b, 7), new Edge(f, 14));
        b.edges = List.of(new Edge(d, 15));
        c.edges = List.of(new Edge(d, 11), new Edge(f, 2));
        d.edges = List.of(new Edge(e, 6));
        e.edges = List.of();
        f.edges = List.of(new Edge(e, 9));
        List<Vertex> graph = List.of(a, b, c, d, e, f);
        computeShortestPaths(graph, a);
        graph.forEach(v -> System.out.printf("%s: %d (prev: %s)%n", v.label, v.dist, v.prev == null ? "none" : v.prev.label));
    }

    /**
     * Computes shortest distances from {@code source}, writing the result into
     * {@link Vertex#dist} and {@link Vertex#prev}. Vertices that are never reached
     * keep {@code Integer.MAX_VALUE}.
     *
     * <p>Package-private so the algorithm can be exercised on graphs other than the
     * sample one.
     *
     * @param graph  all vertices of the graph (not read; the search follows edges from
     *               {@code source})
     * @param source start vertex; its distance is set to 0
     */
    static void computeShortestPaths(List<Vertex> graph, Vertex source) {
        PriorityQueue<QueueEntry> frontier =
                new PriorityQueue<>(Comparator.comparingInt((QueueEntry entry) -> entry.dist));
        source.dist = 0;
        frontier.offer(new QueueEntry(source, 0));
        while (!frontier.isEmpty()) {
            Vertex current = frontier.poll().vertex;
            if (current.visited) {
                continue; // stale snapshot of an already finalized vertex
            }
            current.visited = true;
            for (Edge e : current.edges) {
                Vertex neighbor = e.target;
                if (neighbor.visited) {
                    continue;
                }
                // Widen to long so a large distance plus a large weight cannot wrap.
                long newDist = (long) current.dist + e.weight;
                if (newDist < neighbor.dist) {
                    neighbor.dist = (int) newDist;
                    neighbor.prev = current;
                    frontier.offer(new QueueEntry(neighbor, neighbor.dist));
                }
            }
        }
    }

    /** Immutable heap entry: a vertex together with its distance at insertion time. */
    private static final class QueueEntry {
        final Vertex vertex;
        final int dist;

        QueueEntry(Vertex vertex, int dist) {
            this.vertex = vertex;
            this.dist = dist;
        }
    }

    /** Graph node carrying its outgoing edges and the shortest-path result. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();
        int dist = Integer.MAX_VALUE;
        Vertex prev = null;
        boolean visited = false;

        Vertex(String l) {
            label = l;
        }

        @Override
        public String toString() {
            return label;
        }
    }

    /** Directed edge to {@link #target} with the given weight. */
    static class Edge {
        Vertex target;
        int weight;

        Edge(Vertex t, int w) {
            target = t;
            weight = w;
        }
    }
}
Limitation: Dijkstra's algorithm fails with negative edge weights, as it assumes once a node is marked visited, its shortest distance is finalized.
Bellman-Ford Algorithm
- Initialize the distance for the source to 0 and all others to infinity.
- Relax all edges `|V| - 1` times (where `|V|` is the vertex count).
- Check for negative cycles: if an edge can still be relaxed after `|V| - 1` iterations, a negative cycle exists.
import java.util.*;
/**
 * Bellman-Ford single-source shortest paths with negative-cycle detection.
 *
 * <p>Distances and predecessors are stored on the vertices. A vertex whose distance
 * is still {@code Integer.MAX_VALUE} has not been reached and is never used as the
 * start of a relaxation.
 */
public class BellmanFord {

    /** Runs the algorithm on a sample graph and prints the outcome. */
    public static void main(String[] args) {
        // Sample graph whose cycle A -> B -> C -> A has total weight 2 - 4 + 1 = -1
        Vertex a = new Vertex("A");
        Vertex b = new Vertex("B");
        Vertex c = new Vertex("C");
        Vertex d = new Vertex("D");
        a.edges = List.of(new Edge(b, 2));
        b.edges = List.of(new Edge(c, -4));
        c.edges = List.of(new Edge(d, 1), new Edge(a, 1));
        d.edges = List.of();
        List<Vertex> graph = List.of(a, b, c, d);

        if (computeShortestPaths(graph, a)) {
            System.out.println("Negative-weight cycle detected");
            return;
        }
        for (Vertex v : graph) {
            System.out.printf("%s: %d (prev: %s)%n", v.label, v.dist, v.prev == null ? "none" : v.prev.label);
        }
    }

    /**
     * Relaxes every edge up to {@code |V| - 1} times, then checks once more for an
     * edge that can still be shortened.
     *
     * @param graph  all vertices of the graph
     * @param source start vertex; its distance is set to 0
     * @return {@code true} if some edge is still relaxable after the passes, which
     *         indicates a negative-weight cycle
     */
    private static boolean computeShortestPaths(List<Vertex> graph, Vertex source) {
        source.dist = 0;
        final int maxPasses = graph.size() - 1;

        // Stop early as soon as a full pass changes nothing.
        int pass = 0;
        while (pass < maxPasses && relaxEveryEdge(graph)) {
            pass++;
        }

        return graph.stream()
                .anyMatch(from -> from.edges.stream().anyMatch(via -> improves(from, via)));
    }

    /** Performs one full relaxation pass; returns whether any distance changed. */
    private static boolean relaxEveryEdge(List<Vertex> graph) {
        boolean changed = false;
        for (Vertex from : graph) {
            for (Edge via : from.edges) {
                if (improves(from, via)) {
                    via.target.dist = from.dist + via.weight;
                    via.target.prev = from;
                    changed = true;
                }
            }
        }
        return changed;
    }

    /** True when {@code from} has been reached and going through {@code via} shortens its target. */
    private static boolean improves(Vertex from, Edge via) {
        return from.dist != Integer.MAX_VALUE && from.dist + via.weight < via.target.dist;
    }

    /** Graph node carrying its outgoing edges and the shortest-path result. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();
        int dist = Integer.MAX_VALUE;
        Vertex prev = null;

        Vertex(String name) {
            label = name;
        }
    }

    /** Directed edge to {@link #target}; the weight may be negative. */
    static class Edge {
        Vertex target;
        int weight;

        Edge(Vertex destination, int cost) {
            target = destination;
            weight = cost;
        }
    }
}
Floyd-Warshall Algorithm
Computes shortest paths between all pairs of vertices using dynamic programming. Handles negative edge weights (but not negative cycles).
import java.util.*;
import java.util.stream.Collectors;
/**
 * Floyd-Warshall all-pairs shortest paths with path reconstruction.
 *
 * <p>Vertices are addressed by their index in the graph list. A distance of
 * {@code Integer.MAX_VALUE} means no path is known. A negative value on the diagonal
 * of the distance matrix indicates a negative-weight cycle.
 */
public class FloydWarshall {

    /** Runs the algorithm on a sample graph and prints the matrix and one path. */
    public static void main(String[] args) {
        Vertex v1 = new Vertex("v1");
        Vertex v2 = new Vertex("v2");
        Vertex v3 = new Vertex("v3");
        Vertex v4 = new Vertex("v4");
        v1.edges = List.of(new Edge(v3, -2));
        v2.edges = List.of(new Edge(v1, 4), new Edge(v3, 3));
        v3.edges = List.of(new Edge(v4, 2));
        v4.edges = List.of(new Edge(v2, -1));
        List<Vertex> graph = List.of(v1, v2, v3, v4);

        AllPairsResult result = computeAllPairsShortestPaths(graph);
        if (hasNegativeCycle(result.dist)) {
            System.out.println("Negative-weight cycle detected");
        } else {
            printDistanceMatrix(result.dist, graph);
            printPath(result.prev, graph, 1, 0); // v2 -> v1
        }
    }

    /** Returns whether any vertex can reach itself with negative total weight. */
    private static boolean hasNegativeCycle(int[][] dist) {
        for (int i = 0; i < dist.length; i++) {
            if (dist[i][i] < 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes the distance and predecessor matrices for the given graph.
     *
     * <p>When several edges connect the same ordered pair of vertices, the lightest
     * one is used. A negative self-loop is kept on the diagonal so that it shows up
     * as a negative cycle; non-negative self-loops are ignored.
     *
     * <p>Package-private so the algorithm can be exercised on graphs other than the
     * sample one.
     *
     * @param graph all vertices; row/column {@code i} of the result refers to
     *              {@code graph.get(i)}
     * @return the distance matrix and, for each pair, the vertex preceding the
     *         destination on the shortest path
     */
    static AllPairsResult computeAllPairsShortestPaths(List<Vertex> graph) {
        int n = graph.size();
        int[][] dist = new int[n][n];
        Vertex[][] prev = new Vertex[n][n];

        // Initialize distance and predecessor matrices from the direct edges
        for (int i = 0; i < n; i++) {
            Vertex start = graph.get(i);
            Map<Vertex, Integer> edgeWeights = new HashMap<>();
            for (Edge e : start.edges) {
                // Parallel edges: only the lightest one can be on a shortest path.
                edgeWeights.merge(e.target, e.weight, Integer::min);
            }
            for (int j = 0; j < n; j++) {
                Vertex end = graph.get(j);
                if (i == j) {
                    dist[i][j] = Math.min(0, edgeWeights.getOrDefault(end, 0));
                } else {
                    dist[i][j] = edgeWeights.getOrDefault(end, Integer.MAX_VALUE);
                    prev[i][j] = edgeWeights.containsKey(end) ? start : null;
                }
            }
        }

        // Dynamic programming: k = intermediate node
        for (int k = 0; k < n; k++) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    // Skip unknown legs so MAX_VALUE is never used in arithmetic.
                    if (dist[i][k] != Integer.MAX_VALUE && dist[k][j] != Integer.MAX_VALUE) {
                        int viaK = dist[i][k] + dist[k][j];
                        if (viaK < dist[i][j]) {
                            dist[i][j] = viaK;
                            prev[i][j] = prev[k][j];
                        }
                    }
                }
            }
        }
        return new AllPairsResult(dist, prev);
    }

    /**
     * Prints the shortest path between two vertices by walking the predecessor
     * matrix backwards from the destination, or a notice when no path exists.
     */
    private static void printPath(Vertex[][] prev, List<Vertex> graph, int startIdx, int endIdx) {
        Deque<String> pathStack = new ArrayDeque<>();
        int currIdx = endIdx;
        pathStack.push(graph.get(currIdx).label);
        while (currIdx != startIdx) {
            Vertex p = prev[startIdx][currIdx];
            if (p == null) {
                System.out.printf("[%s -> %s] No path exists%n", graph.get(startIdx).label, graph.get(endIdx).label);
                return;
            }
            pathStack.push(p.label);
            currIdx = graph.indexOf(p);
        }
        System.out.printf("[%s -> %s] %s%n", graph.get(startIdx).label, graph.get(endIdx).label, String.join(" -> ", pathStack));
    }

    /** Prints the distance matrix as a table with 4-character-wide columns. */
    private static void printDistanceMatrix(int[][] dist, List<Vertex> graph) {
        System.out.println("All-Pairs Shortest Distances:");
        // Blank corner cell as wide as the row labels so the header lines up.
        System.out.printf("%4s", "");
        graph.forEach(v -> System.out.printf("%4s", v.label));
        System.out.println();
        for (int i = 0; i < dist.length; i++) {
            System.out.printf("%4s", graph.get(i).label);
            for (int j = 0; j < dist[i].length; j++) {
                String val = dist[i][j] == Integer.MAX_VALUE ? "∞" : String.valueOf(dist[i][j]);
                System.out.printf("%4s", val);
            }
            System.out.println();
        }
    }

    /** Typed holder for the two matrices produced by the algorithm. */
    static final class AllPairsResult {
        final int[][] dist;
        final Vertex[][] prev;

        AllPairsResult(int[][] dist, Vertex[][] prev) {
            this.dist = dist;
            this.prev = prev;
        }
    }

    /** Graph node identified by a label, holding its outgoing edges. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();

        Vertex(String l) {
            label = l;
        }
    }

    /** Directed edge to {@link #target}; the weight may be negative. */
    static class Edge {
        Vertex target;
        int weight;

        Edge(Vertex t, int w) {
            target = t;
            weight = w;
        }
    }
}
Minimum Spanning Tree (MST) Algorithms
Prim's Algorithm
- Initialize MST vertex set with source, set source's key (edge weight to MST) to 0, others to infinity.
- Select vertex not in MST with smallest key, add to MST.
- Update keys of neighbors not in MST to min(current key, edge weight from new MST vertex).
- Repeat until all vertices are in MST.
import java.util.*;
/**
 * Prim's minimum spanning tree on an undirected weighted graph, using a linear scan
 * to pick the next vertex.
 *
 * <p>Each undirected edge is stored twice, once in the adjacency list of each
 * endpoint. Every vertex tracks the cheapest known edge connecting it to the tree
 * built so far ({@link Vertex#key} / {@link Vertex#mstEdge}).
 */
public class PrimMST {

    /** Builds the sample graph, computes its MST from v1 and prints edges and total weight. */
    public static void main(String[] args) {
        Vertex v1 = new Vertex("v1");
        Vertex v2 = new Vertex("v2");
        Vertex v3 = new Vertex("v3");
        Vertex v4 = new Vertex("v4");
        Vertex v5 = new Vertex("v5");
        Vertex v6 = new Vertex("v6");
        Vertex v7 = new Vertex("v7");
        // Edge takes (start, end, weight); the start is the vertex owning the list.
        v1.edges = List.of(new Edge(v1, v2, 2), new Edge(v1, v3, 4), new Edge(v1, v4, 1));
        v2.edges = List.of(new Edge(v2, v1, 2), new Edge(v2, v4, 3), new Edge(v2, v5, 10));
        v3.edges = List.of(new Edge(v3, v1, 4), new Edge(v3, v4, 2), new Edge(v3, v6, 5));
        v4.edges = List.of(new Edge(v4, v1, 1), new Edge(v4, v2, 3), new Edge(v4, v3, 2),
                new Edge(v4, v5, 7), new Edge(v4, v6, 8), new Edge(v4, v7, 4));
        v5.edges = List.of(new Edge(v5, v2, 10), new Edge(v5, v4, 7), new Edge(v5, v7, 6));
        v6.edges = List.of(new Edge(v6, v3, 5), new Edge(v6, v4, 8), new Edge(v6, v7, 1));
        v7.edges = List.of(new Edge(v7, v4, 4), new Edge(v7, v5, 6), new Edge(v7, v6, 1));
        List<Vertex> graph = List.of(v1, v2, v3, v4, v5, v6, v7);

        List<Edge> mstEdges = buildMST(graph, v1);
        int totalWeight = mstEdges.stream().mapToInt(e -> e.weight).sum();
        System.out.println("MST Edges:");
        mstEdges.forEach(e -> System.out.printf("%s-%s (%d)%n", e.start.label, e.end.label, e.weight));
        System.out.println("Total MST Weight: " + totalWeight);
    }

    /**
     * Grows the tree one vertex at a time, always attaching the vertex with the
     * smallest key.
     *
     * <p>A vertex that is picked while it has no connecting edge (the source, or a
     * vertex not connected to anything picked so far) contributes no edge to the result.
     *
     * @param graph  all vertices of the graph
     * @param source vertex the tree starts from; its key is set to 0
     * @return the selected edges in the order their vertices were attached
     */
    private static List<Edge> buildMST(List<Vertex> graph, Vertex source) {
        List<Vertex> unlinked = new ArrayList<>(graph);
        List<Edge> mst = new ArrayList<>();
        source.key = 0;
        while (!unlinked.isEmpty()) {
            Vertex current = getMinKeyNode(unlinked);
            unlinked.remove(current);
            current.inMST = true;
            if (current.mstEdge != null) {
                mst.add(current.mstEdge);
            }
            updateNeighborKeys(current);
        }
        return mst;
    }

    /** Returns the first vertex in the list with the smallest key. */
    private static Vertex getMinKeyNode(List<Vertex> unlinked) {
        Vertex minNode = unlinked.get(0);
        for (int i = 1; i < unlinked.size(); i++) {
            if (unlinked.get(i).key < minNode.key) {
                minNode = unlinked.get(i);
            }
        }
        return minNode;
    }

    /** Offers each edge of the newly attached vertex to its neighbours outside the tree. */
    private static void updateNeighborKeys(Vertex current) {
        for (Edge e : current.edges) {
            Vertex neighbor = e.end;
            if (!neighbor.inMST && e.weight < neighbor.key) {
                neighbor.key = e.weight;
                neighbor.mstEdge = new Edge(current, neighbor, e.weight);
            }
        }
    }

    /** Graph node carrying its adjacency list and Prim bookkeeping. */
    static class Vertex {
        String label;
        List<Edge> edges = new ArrayList<>();
        int key = Integer.MAX_VALUE;
        boolean inMST = false;
        Edge mstEdge = null;

        Vertex(String l) {
            label = l;
        }
    }

    /** Weighted edge from {@link #start} to {@link #end}. */
    static class Edge {
        Vertex start;
        Vertex end;
        int weight;

        Edge(Vertex s, Vertex e, int w) {
            start = s;
            end = e;
            weight = w;
        }
    }
}
Kruskal's Algorithm
- Sort all edges by weight in ascending order.
- Initialize MST edge set and DSU structure to track connected components.
- Iterate over the sorted edges: add an edge to the MST if it connects two different components (no cycle).
- Stop when the MST has `|V| - 1` edges.
import java.util.*;
/**
 * Kruskal's minimum spanning tree: take edges in ascending weight order and keep
 * each one that joins two different components, tracked with a disjoint set union.
 *
 * <p>Vertices are identified by 0-based indices.
 */
public class KruskalMST {

    /** Builds the sample edge list, computes the MST and prints edges and total weight. */
    public static void main(String[] args) {
        int totalVertices = 7;
        List<WeightedEdge> allEdges = new ArrayList<>();
        allEdges.add(new WeightedEdge(0, 1, 2));
        allEdges.add(new WeightedEdge(0, 2, 4));
        allEdges.add(new WeightedEdge(0, 3, 1));
        allEdges.add(new WeightedEdge(1, 3, 3));
        allEdges.add(new WeightedEdge(1, 4, 10));
        allEdges.add(new WeightedEdge(2, 3, 2));
        allEdges.add(new WeightedEdge(2, 5, 5));
        allEdges.add(new WeightedEdge(3, 4, 7));
        allEdges.add(new WeightedEdge(3, 5, 8));
        allEdges.add(new WeightedEdge(3, 6, 4));
        allEdges.add(new WeightedEdge(4, 6, 6));
        allEdges.add(new WeightedEdge(5, 6, 1));
        List<WeightedEdge> mst = buildMST(totalVertices, allEdges);
        int totalWeight = mst.stream().mapToInt(e -> e.weight).sum();
        System.out.println("MST Edges:");
        mst.forEach(e -> System.out.printf("v%d-v%d (%d)%n", e.u + 1, e.v + 1, e.weight));
        System.out.println("Total MST Weight: " + totalWeight);
    }

    /**
     * Selects MST edges from {@code edges}.
     *
     * <p>The input list is not modified: sorting happens on a private copy, so
     * immutable lists are accepted and the caller's ordering is preserved.
     * Package-private so the algorithm can be exercised on other edge lists.
     *
     * @param vertexCount number of vertices; edge endpoints must lie in
     *                    {@code [0, vertexCount)}
     * @param edges       candidate edges in any order
     * @return the chosen edges in ascending weight order; fewer than
     *         {@code vertexCount - 1} if the edges do not connect all vertices
     */
    static List<WeightedEdge> buildMST(int vertexCount, List<WeightedEdge> edges) {
        List<WeightedEdge> byWeight = new ArrayList<>(edges);
        Collections.sort(byWeight);
        DSU dsu = new DSU(vertexCount);
        List<WeightedEdge> mst = new ArrayList<>();
        for (WeightedEdge e : byWeight) {
            if (mst.size() == vertexCount - 1) {
                break; // a spanning tree is complete
            }
            if (dsu.merge(e.u, e.v)) {
                mst.add(e);
            }
        }
        return mst;
    }

    /** Undirected edge between vertex indices {@code u} and {@code v}, ordered by weight. */
    static class WeightedEdge implements Comparable<WeightedEdge> {
        int u, v, weight;

        WeightedEdge(int u, int v, int w) {
            this.u = u;
            this.v = v;
            weight = w;
        }

        @Override
        public int compareTo(WeightedEdge o) {
            return Integer.compare(this.weight, o.weight);
        }
    }

    /** Disjoint set union with path compression and union by size. */
    static class DSU {
        private final int[] parent;
        private final int[] size;

        DSU(int n) {
            parent = new int[n];
            size = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
                size[i] = 1;
            }
        }

        /** Returns the representative of {@code x}'s set, compressing the path on the way. */
        int findRoot(int x) {
            if (parent[x] != x) {
                parent[x] = findRoot(parent[x]);
            }
            return parent[x];
        }

        /**
         * Unites the sets containing {@code x} and {@code y}. The arguments may be
         * any members of their sets; roots are resolved here.
         *
         * @return {@code true} if two different sets were joined, {@code false} if
         *         both elements were already in the same set
         */
        boolean merge(int x, int y) {
            int rootX = findRoot(x);
            int rootY = findRoot(y);
            if (rootX == rootY) {
                return false;
            }
            // Attach the smaller tree under the larger one.
            if (size[rootX] < size[rootY]) {
                int temp = rootX;
                rootX = rootY;
                rootY = temp;
            }
            parent[rootY] = rootX;
            size[rootX] += size[rootY];
            return true;
        }
    }
}
Disjoint Set Union (DSU) / Union-Find
Basic Implementation
import java.util.Arrays;
/**
 * Minimal disjoint set union (union-find) over the elements {@code 0 .. size-1}.
 *
 * <p>This is the unoptimized baseline: lookups walk the parent chain without
 * compressing it and unions ignore tree size, so chains can grow linearly.
 */
public class BasicDSU {
    private final int[] parent;

    /** Creates {@code size} singleton sets; every element starts as its own root. */
    BasicDSU(int size) {
        parent = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i;
        }
    }

    /** Returns the representative of the set containing {@code x}. */
    int findRoot(int x) {
        int current = x;
        while (parent[current] != current) {
            current = parent[current];
        }
        return current;
    }

    /**
     * Unites the set containing {@code rootY} into the set containing {@code rootX}.
     *
     * <p>The arguments may be any members of their sets. Both are resolved to their
     * roots first, because re-parenting a non-root element would detach it from the
     * rest of its set. When roots are passed in, the result is the same as a plain
     * {@code parent[rootY] = rootX}.
     */
    void merge(int rootX, int rootY) {
        int targetRoot = findRoot(rootX);
        int movedRoot = findRoot(rootY);
        if (targetRoot != movedRoot) {
            parent[movedRoot] = targetRoot;
        }
    }

    @Override
    public String toString() {
        return Arrays.toString(parent);
    }
}
Path Compression Optimization
// Path-compressing lookup: returns the root of x's set and re-points every
// element on the walked path directly at that root.
int findRoot(int x) {
    // First pass: follow parent links up to the root.
    int root = x;
    while (parent[root] != root) {
        root = parent[root];
    }
    // Second pass: walk the same path again, attaching each element to the root.
    while (parent[x] != root) {
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
Union by Size Optimization
import java.util.Arrays;
/**
 * Disjoint set union over {@code 0 .. size-1} with path compression on lookup and
 * union by component size on merge.
 */
public class OptimizedDSU {
    private final int[] parent;
    private final int[] componentSize;

    /** Creates {@code size} singleton components. */
    OptimizedDSU(int size) {
        parent = new int[size];
        componentSize = new int[size];
        Arrays.fill(componentSize, 1);
        for (int element = 0; element < size; element++) {
            parent[element] = element;
        }
    }

    /**
     * Returns the root of {@code x}'s component and points every element on the
     * walked path directly at that root.
     */
    int findRoot(int x) {
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        int walker = x;
        while (parent[walker] != root) {
            int next = parent[walker];
            parent[walker] = root;
            walker = next;
        }
        return root;
    }

    /**
     * Joins the components of {@code x} and {@code y}. The smaller component is
     * attached under the larger one; on equal sizes {@code y}'s root goes under
     * {@code x}'s root. Does nothing if both are already in the same component.
     */
    void merge(int x, int y) {
        int first = findRoot(x);
        int second = findRoot(y);
        if (first != second) {
            int keeper = componentSize[first] < componentSize[second] ? second : first;
            int absorbed = (keeper == first) ? second : first;
            parent[absorbed] = keeper;
            componentSize[keeper] += componentSize[absorbed];
        }
    }

    @Override
    public String toString() {
        String parents = Arrays.toString(parent);
        String sizes = Arrays.toString(componentSize);
        return String.format("Parent: %s%nComponent Sizes: %s", parents, sizes);
    }

    /** Small demo: three merges over five elements, then prints the internal arrays. */
    public static void main(String[] args) {
        OptimizedDSU dsu = new OptimizedDSU(5);
        dsu.merge(1, 2);
        dsu.merge(3, 4);
        dsu.merge(1, 3);
        System.out.println(dsu);
    }
}
Related Practice Problems
| Problem Number | Problem Title | Relevant Algorithm(s) |
|---|---|---|
| 547 | Number of Provinces | DFS, BFS, DSU |
| 797 | All Paths From Source to Target | DFS, BFS |
| 1584 | Min Cost to Connect All Points | Prim's, Kruskal's MST |
| 743 | Network Delay Time | Dijkstra's, Bellman-Ford |
| 787 | Cheapest Flights Within K Stops | Bellman-Ford, BFS with DP |
| 207 | Course Schedule | Kahn's, DFS Topological Sort |
| 210 | Course Schedule II | Kahn's, DFS Topological Sort |