# CODE

class Solution:
    def minCostConnectPoints(self, points: List[List[int]]) -> int:
        """Return the minimum total cost to connect all points.

        The cost of connecting two points is their Manhattan distance
        |x1 - x2| + |y1 - y2|. The answer is the weight of a minimum
        spanning tree of the implicit complete graph over ``points``.

        Uses the array-based form of Prim's algorithm. This is the right
        choice for a complete graph. It runs in O(N^2) time and O(N) extra
        memory, and never materialises the O(N^2) edge list or a heap of
        candidate edges.

        Args:
            points: list of ``[x, y]`` integer coordinates.

        Returns:
            The MST weight. Returns 0 for zero or one point.
        """
        n = len(points)
        if n == 0:
            return 0

        inf = float("inf")
        # best[i]: cheapest known edge weight from the current tree to point i.
        # Point 0 seeds the tree at cost 0.
        best = [inf] * n
        best[0] = 0
        in_tree = [False] * n
        total = 0

        # Each iteration adds exactly one new point to the tree.
        for _ in range(n):
            # Pick the closest point not yet in the tree.
            u = -1
            for v in range(n):
                if not in_tree[v] and (u == -1 or best[v] < best[u]):
                    u = v

            in_tree[u] = True
            # After the first iteration every remaining best[] entry is an
            # int, because the graph is complete, so total stays an int.
            total += best[u]

            # Relax: the new tree member may offer a cheaper edge to others.
            ux, uy = points[u]
            for v in range(n):
                if not in_tree[v]:
                    vx, vy = points[v]
                    d = abs(ux - vx) + abs(uy - vy)
                    if d < best[v]:
                        best[v] = d

        return total
            
        

# Last updated