백준 최소 신장 트리(MST)

Categories: Tags:


서론

최소 신장 트리

신장 트리(Spanning tree)에 대한 개념을 몰랐기 때문에 학습하고 풀었습니다. 학습한 후에는 그렇게 어려운 문제는 아니었습니다. Union Find 개념이 사용됩니다. 이렇게 한 개씩 알고리즘을 배워가면 나중에 아는 알고리즘이 나올 거라고 생각합니다.

신장 트리(Spanning tree)

연결 그래프의 부분 그래프로서 그 그래프의 모든 정점과 간선의 부분 집합으로 구성되는 트리. 모든 노드는 적어도 하나의 간선에 연결되어 있어야 한다.

쉽게 말하자면 그래프에서 모든 노드 N개를 연결하는 간선 N-1개를 가진 트리입니다. 그리고 그중에서 최소 신장 트리는 (Minimum Spanning Tree) 그 간선들의 합이 최소가 되는 신장트리입니다. MST를 구하는 문제가 자주 나옵니다.

MST를 구하는 방법으로는 대표적으로 Kruskal, Prim 알고리즘이 있습니다. Kruskal은 간선의 개수가 적을 때, Prim은 노드의 개수가 적을 때 최적화 되어있다고 하는데 저는 그냥 다 Kruskal로 풀었습니다. Kruskal 알고리즘을 구현하는 방법은 다음과 같습니다.

  1. 기본적으로 start, edge, weight를 가지는 data-set을 만들어주어야 합니다.
  2. data-set을 weight 기준으로 오름차순 정렬하여 최소 weight를 가지는 edge를 찾습니다. 이 edge를 light edge라고 합니다.
  3. light edge를 연결했을 때도 그래프가 트리 속성을 만족한다면 연결합니다. 이 edge 연결을 safe 하다고 합니다. (여기서 Union Find 개념이 사용됩니다.)

간략하게 말하자면 Kruskal 알고리즘은 safe게 연결할 수 있는 light edge를 반복해서 찾는 알고리즘입니다.

제한 조건

  • 시간제한

위의 방법에서 1, 2를 실행하는데 힙 정렬을 사용하면 O(ElogE)의 시간복잡도를 가지게 됩니다. 그리고 3번에서 safe 한지를 판단할 때 Union Find를 사용하는데 이는 앞서서 배웠듯이 아크만 함수에 의해 상수처럼 취급되기 때문에 MST를 구하는 Kruskal알고리즘의 시간복잡도는 O(ElogE)입니다.

  • 메모리 제한

weight 순으로 오름차순 정렬을 할 때 heap을 사용하게 되는데 문제의 조건에 맞춰서 heap에 데이터가 쌓이는 양만 잘 생각하면 문제 없습니다.

9372번: 상근이의 여행

weight가 없는 Spanning tree 구현 문제입니다. weight가 없으므로 정렬할 필요가 없고 그냥 입력받는 바로바로 cycle을 만들지 않는지 확인하면서 연결하면 됩니다.

import sys
input = sys.stdin.readline


def find(a):
    if parent[a] == a:
        return a
    parent[a] = find(parent[a])
    return parent[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a == b:
        return
    parent[a] = b


for __ in range(int(input())):
    cnt = 0
    N, M = map(int, input().split())
    parent = [i for i in range(N+1)]
    for __ in range(M):
        a, b = map(int, input().split())
        if find(a) == find(b):
            continue
        union(a, b)
        cnt += 1
    print(cnt)

1197번: 최소 스패닝 트리

정확하게 MST를 구하는 문제입니다. 위에서 설명한 대로 하면 됩니다. MST 개념에 대한 구현을 연습해보기 좋은 문제입니다.

import sys
import heapq
input = sys.stdin.readline


def find(a):
    if parent[a] == a:
        return a
    parent[a] = find(parent[a])
    return parent[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a == b:
        return
    parent[a] = b


wsum = 0
heap = []
V, E = map(int, input().split())
parent = [i for i in range(V+1)]

for __ in range(E):
    a, b, c = map(int, input().split())
    heapq.heappush(heap, [c, a, b])

while heap:
    c, a, b = heapq.heappop(heap)
    if find(a) == find(b):
        continue
    union(a, b)
    wsum += c

print(wsum)

4386번: 별자리 만들기

위와 똑같은 MST 개념을 묻는 문제이지만 인풋이 start, end, weight(edge 정보)로 주어지지 않습니다. 따라서 인풋을 edge 정보로 변환하여야 하는데 n개의 좌표가 자신 빼고 모든 n-1개의 좌표와 대응하니 edge 개수는 n(n-1)/2가 됩니다. n^2에서 뭔가 비용을 초과할 수 있을 것 같기 때문에 시간제한이나 메모리 제한을 생각해야 합니다. 문제에선 n=100이 최댓값이기 때문에 충분히 통과할 수 있습니다.

import sys
import heapq
input = sys.stdin.readline

parent = []


def dist(a, b):
    x1, y1 = a
    x2, y2 = b
    return ((x1-x2)**2 + (y1-y2)**2)**0.5


def find(a):
    if a == parent[a]:
        return a
    parent[a] = find(parent[a])
    return parent[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a == b:
        return
    parent[b] = a


n = int(input())
v = []
heap = []
parent = [i for i in range(n)]
wsum = 0

for i in range(n):
    x, y = map(float, input().split())
    v.append([x, y])

for i in range(len(v)):
    for j in range(i+1, len(v)):
        heapq.heappush(heap, [dist(v[i], v[j]), i, j])

while heap:
    w, a, b = heapq.heappop(heap)
    if find(a) == find(b):
        continue
    union(a, b)
    wsum += w

print(wsum)

1744번: 우주신과의 교감

역시 edge 정보를 주지 않았고 변환하여 만들어야 합니다. 이번엔 n의 최댓값이 1,000이기 때문에 제곱을 하면 조금 비용이 큰데 이미 연결된 정보로 비용을 줄여야 하는 문제였습니다. 이미 연결되었기 때문에 이 부분은 계산을 하면 안되는데 인풋을 받는 곳에서 미리 Union을 해서 해결해 주었습니다. 똑같은 MST 문제입니다.

import sys
import heapq
input = sys.stdin.readline

parent = []


def dist(a, b):
    x1, y1 = a
    x2, y2 = b
    return ((x1-x2)**2 + (y1-y2)**2)**0.5


def find(a):
    if a == parent[a]:
        return a
    parent[a] = find(parent[a])
    return parent[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a == b:
        return
    parent[b] = a


N, M = map(int, input().split())
v = []
heap = []
parent = [i for i in range(N)]
wsum = 0

for i in range(N):
    x, y = map(float, input().split())
    v.append([x, y])
for i in range(M):
    a, b = map(int, input().split())
    union(a-1, b-1)

for i in range(len(v)):
    for j in range(i+1, len(v)):
        heapq.heappush(heap, [dist(v[i], v[j]), i, j])

while heap:
    w, a, b = heapq.heappop(heap)
    if find(a) == find(b):
        continue
    union(a, b)
    wsum += w

print(f'{wsum:.2f}')

2887번: 행성 터널

앞선 두 문제와 같은 케이스인데 이번엔 N이 100,000으로 n^2를 할 수 없는 상황입니다. 따라서 edge 정보를 수를 줄여서 넣어야 하는데 이 부분은 구글링을 통해서 해결했습니다. 이렇게 MST라는 개념을 적용하는 것뿐만 아니라 다른 개념이 들어오면 문제가 너무 어려워지는 것 같습니다. 특히 이렇게 아이디어를 사용해야 하는 것은 더 어렵습니다.

기존에 했던 방법으로 풀게 되면 n^2이 되어 heap에 넣을 시 메모리 초과가 나오게 되는데 이 부분은 문제의 조건을 통해서 해결할 수 있습니다. weight를 구하는 부분이 각 좌표들의 차이의 최솟값이기 때문에 좌표별로 정렬을 해서 빼게 되면 건너 건너 한 개만 heap에 들어가는 후보군으로 두면 되기 때문에 개수가 3(N-1)으로 줄어들게 됩니다.

import sys
import heapq
input = sys.stdin.readline


def find(a):
    if a == parent[a]:
        return a
    parent[a] = find(parent[a])
    return parent[a]


def union(a, b):
    a = find(a)
    b = find(b)
    if a == b:
        return
    parent[b] = a


N = int(input())
v = []
heap = []
parent = [i for i in range(N)]
wsum = 0

for i in range(N):
    x, y, z = map(int, input().split())
    v.append([x, y, z, i])

# 엄청난 아이디어의 부분
for i in range(3):
    v.sort(key=lambda x: x[i])
    for j in range(1, N):
        heapq.heappush(
            heap, (abs(v[j - 1][i] - v[j][i]), v[j - 1][3], v[j][3]))


while heap:
    w, a, b = heapq.heappop(heap)
    if find(a) == find(b):
        continue
    union(a, b)
    wsum += w

print(wsum)

생각

마지막 문제 같은 경우는 사실상 아이디어가 없으면 MST를 잘 구현해도 메모리 제한이나 시간제한에 걸려서 풀 수가 없습니다. 이런 부분의 아이디어를 떠올리는 건 어떻게 훈련해야 할까요?

*추가

17422번을 풀면서 새로운 걸 배웠습니다. 지금까지의 문제들은 모든 분리된 노드들이 연결된다는 보장이 있었기 때문에 별도의 처리를 하지 않아도 맞았지만 17422번처럼 문제의 조건에 의해서 분리된 노드들이 연결된다는 보장이 없을 수 있습니다.

이럴 때는 분리된 노드들이 MST가 될 때 간선의 개수는 N-1개라는 조건을 사용하면 됩니다. 만약 edge_list를 탐색하며 MST를 만들었는데 간선의 개수가 N-1일 경우 분리된 노드들 중 연결되지 않은 것이 있다는 뜻입니다.


BOJ내 다른 글 보기

댓글남기기