백준 2261번 분할정복
서론
앞으로 플레티넘 문제나 재밌었던 문제에 대해서 글을 남기겠습니다. 주어진 점들 중 가장 가까운 두 점을 찾는 문제입니다. 당연히 brute force로 하면 시간 조건에 걸리고 분할정복을 사용해서 풀어야 합니다. 앞서 6549번: 히스토그램에서 가장 큰 직사각형 문제에서 너무 많이 틀려서 당연히 시간 조건 고려해서 분할정복으로 풀었지만, 분할정복의 가장 중요한 부분인 분할된 부분에서의 처리를 n^2로 풀어서 시간초과로 틀렸습니다. 결국 도움받아서 풀었습니다.
제한 조건
생각해보면 당연히 단순 구현 문제가 아니라면 brute force로 풀면 시간 조건에 걸릴 것 같습니다. 몇 문제 풀다 보니까 구현 문제에서도 너무 무식하게 구현해놓으면 느려서 통과를 못 합니다.
- 시간제한 - 1초
시간제한은 1초입니다. 일반적인 컴퓨터로 1초에 1억 번 연산한다는 기준을 가지고 풉니다. 데이터의 개수는 10만 개입니다. O(n^2)로 문제를 풀 게 되면 루프 안에서 O(1)로 연산을 해도 100억 번의 연산이 필요하므로 명백하게 O(n^2)보다 빠른 알고리즘으로 푸는 문제입니다.
- 메모리 제한 - 256MB
메모리 제한은 256MB입니다. 데이터의 개수는 절댓값 1만을 넘지 않는 10만 개의 데이터입니다. 데이터는 4바이트 int형의 데이터가 한 좌표마다 2개씩 있으니 20만 개, 80만 바이트 즉 총 데이터 메모리는 800KB입니다.
데이터를 담을 메모리 조건은 충분하고 분할정복에서 틀리기 쉬운 recursion 형식으로 메모리를 복사하는 것만 주의하면 됩니다. 함수에서 데이터를 호출할 때 인자로 넘기지 말고 데이터는 전역변수로 두고 인덱스값만 넘기면서 코딩하였습니다.
분할 정복 O(nlogn)
데이터를 계속 절반씩 나누고 나눠진 부분에 대해서 처리하는 분할정복입니다. 나누어졌을 때 최솟값은 점이 세 개 이하로 남을 때이며 이때는 그냥 brute force로 각 점 사이의 최솟값을 리턴합니다. 나눠진 부분처리는 나눈 점을 기준으로 왼쪽 분할과 오른쪽 분할에서의 최솟값보다 더 작은 값이 있는지 찾는 방법입니다.
우선 데이터를 분할정복하기 위한 기준을 x축으로 정하고 x축으로 소팅했습니다. 소팅하는 메소드의 시간복잡도는 nlogn이지만 총 코드가 nlogn보다 빠를 것 같지 않았기 때문에 그냥 사용했습니다. 처음에는 부분처리 방법으로 n개의 후보군을 구하고 이 후보군에 대해서 brute force를 했기 때문에 n^2로 실패했습니다. 해답을 못 찾아서 또 도움을 받았습니다. 해답은 x축으로 정렬하고 최솟값보다 더 작은 값이 있는지만 찾은 것 처럼 y축으로도 정렬을 하고 최솟값보다 더 작은 값이 있는지 찾는 방법이었습니다. 근데 이 코드가 n^2에서 조건을 넣어서 빠른 건 알겠는데 최악의 경우는 n^2로 마찬가지인 것 같은데 정확히 어떤 시간복잡도로 표현되는지는 잘 모르겠습니다. 다음에 알게 되면 추가하겠습니다. 코드는 아래와 같습니다.
x_candidate = point_list[avg-lcnt: avg+rcnt+1]
x_candidate = sorted(x_candidate, key=lambda x: x[1])
for i in range(len(x_candidate)-1):
for j in range(i+1, len(x_candidate)):
if min_both > (x_candidate[i][1] - x_candidate[j][1])**2:
tmp = min(tmp, dist_point(x_candidate[i], x_candidate[j]))
else:
break
절반씩 분할정복을 하므로 logn, 나눠진 부분을 처리하는 방법은 최악의 경우에 n^2이지만 추가한 조건으로 인해 n으로 표현하겠습니다. (정확히 모르겠습니다) 따라서 O(nlogn)의 시간복잡도를 가지게 됩니다. 데이터를 인덱스로 참조하기 때문에 공간복잡도 문제도 없습니다.
import sys
input = sys.stdin.readline
MIN_NUM = sys.maxsize
def dist_point(a, b):
return (a[0]-b[0])**2 + (a[1]-b[1])**2
def brute(start, end):
tmp = MIN_NUM
for i in range(start, end+1):
for j in range(i+1, end+1):
tmp = min(tmp, dist_point(point_list[i], point_list[j]))
return tmp
def area_point(avg, start, end, min_both):
tmp = min_both
lcnt = 0
rcnt = 0
while avg-lcnt-1 >= start and min_both > (point_list[avg][0] - point_list[avg-lcnt-1][0])**2:
lcnt += 1
while avg + rcnt+1 <= end and min_both > (point_list[avg][0] - point_list[avg+rcnt+1][0])**2:
rcnt += 1
x_candidate = point_list[avg-lcnt: avg+rcnt+1]
x_candidate = sorted(x_candidate, key=lambda x: x[1])
for i in range(len(x_candidate)-1):
for j in range(i+1, len(x_candidate)):
if min_both > (x_candidate[i][1] - x_candidate[j][1])**2:
tmp = min(tmp, dist_point(x_candidate[i], x_candidate[j]))
else:
break
return tmp
tc = int(input())
point_list = []
for __ in range(tc):
point_list.append(list(map(int, input().split())))
point_list = sorted(point_list, key=lambda x: (x[0]))
def solution(start, end):
cnt = 1
tmp = MIN_NUM
if end - start + 1 <= 3:
return brute(start, end)
else:
avg = (start+end)//2
min_left = solution(start, avg-1)
min_right = solution(avg+1, end)
min_both = min(min_left, min_right)
tmp = area_point(avg, start, end, min_both)
return min(tmp, min_left, min_right)
print(solution(0, tc-1))
생각
분할하는 곳에서의 처리가 결국은 최악의 경우 n^2인 것 같은데 break 조건을 걸면서 정확히 어떤 식으로 시간복잡도가 줄어드는지 궁금합니다. 찾게 되거나 제가 통찰력이 생겨서 알게 되면 다음에 수정하겠습니다.
BOJ내 다른 글 보기
댓글남기기