diff --git a/src/graphs/dijkstra.py b/src/graphs/dijkstra.py new file mode 100644 index 0000000..dd6925b --- /dev/null +++ b/src/graphs/dijkstra.py @@ -0,0 +1,76 @@ +""" +Dijkstra's Shortest Path Algorithm + +Finds the shortest path from a source vertex to all other vertices +in a weighted graph with non-negative edge weights. + +Time Complexity: O((V + E) log V) with a min-heap +Space Complexity: O(V) +""" + +import heapq +from typing import Dict, List, Tuple + + +def dijkstra( + graph: Dict[str, List[Tuple[str, float]]], source: str +) -> Tuple[Dict[str, float], Dict[str, str | None]]: + """ + Compute shortest distances from source to all reachable vertices. + + Args: + graph: Adjacency list where graph[u] = [(v, weight), ...]. + Every vertex that appears as a neighbour must also be a key + (even if its list is empty). + source: Starting vertex. + + Returns: + distances: {vertex: shortest_distance} (unreachable → float('inf')) + previous: {vertex: predecessor} for path reconstruction (source → None) + """ + distances: Dict[str, float] = {v: float("inf") for v in graph} + previous: Dict[str, str | None] = {v: None for v in graph} + distances[source] = 0.0 + + # Min-heap: (distance, vertex) + heap: List[Tuple[float, str]] = [(0.0, source)] + + while heap: + dist_u, u = heapq.heappop(heap) + + # Skip stale entries + if dist_u > distances[u]: + continue + + for v, weight in graph[u]: + alt = dist_u + weight + if alt < distances[v]: + distances[v] = alt + previous[v] = u + heapq.heappush(heap, (alt, v)) + + return distances, previous + + +def shortest_path( + graph: Dict[str, List[Tuple[str, float]]], source: str, target: str +) -> Tuple[float, List[str]]: + """ + Return the shortest distance and path from source to target. + + Returns: + (distance, path) where path is a list of vertices from source to target. + If unreachable, returns (float('inf'), []). + """ + distances, previous = dijkstra(graph, source) + + if distances[target] == float("inf"): + return float("inf"), [] + + path: List[str] = [] + current: str | None = target + while current is not None: + path.append(current) + current = previous[current] + + return distances[target], list(reversed(path)) diff --git a/tests/test_dijkstra.py b/tests/test_dijkstra.py new file mode 100644 index 0000000..1aaeb53 --- /dev/null +++ b/tests/test_dijkstra.py @@ -0,0 +1,70 @@ +"""Tests for Dijkstra's shortest path algorithm.""" + +import pytest +from src.graphs.dijkstra import dijkstra, shortest_path + + +@pytest.fixture +def simple_graph(): + """ + A -> B (1) -> D (3) + A -> C (4) + B -> C (2) + C -> D (1) + """ + return { + "A": [("B", 1), ("C", 4)], + "B": [("C", 2), ("D", 3)], + "C": [("D", 1)], + "D": [], + } + + +class TestDijkstra: + def test_distances(self, simple_graph): + dist, _ = dijkstra(simple_graph, "A") + assert dist["A"] == 0 + assert dist["B"] == 1 + assert dist["C"] == 3 # A->B->C + assert dist["D"] == 4 # A->B->C->D + + def test_previous_pointers(self, simple_graph): + _, prev = dijkstra(simple_graph, "A") + assert prev["A"] is None + assert prev["B"] == "A" + assert prev["C"] == "B" + assert prev["D"] in ("B", "C") # both paths cost 4 + + def test_unreachable_vertex(self): + graph = {"A": [("B", 1)], "B": [], "C": []} + dist, _ = dijkstra(graph, "A") + assert dist["C"] == float("inf") + + def test_single_vertex(self): + graph = {"X": []} + dist, prev = dijkstra(graph, "X") + assert dist["X"] == 0 + assert prev["X"] is None + + +class TestShortestPath: + def test_path_reconstruction(self, simple_graph): + cost, path = shortest_path(simple_graph, "A", "D") + assert cost == 4 + assert path in (["A", "B", "C", "D"], ["A", "B", "D"]) # tie + + def test_direct_neighbour(self, simple_graph): + cost, path = shortest_path(simple_graph, "A", "B") + assert cost == 1 + assert path == ["A", "B"] + + def test_same_source_target(self, simple_graph): + cost, path = shortest_path(simple_graph, "A", "A") + assert cost == 0 + assert path == ["A"] + + def test_unreachable_returns_empty(self): + graph = {"A": [], "B": []} + cost, path = shortest_path(graph, "A", "B") + assert cost == float("inf") + assert path == []