본문 바로가기
기술스택을 쌓아보자/Python

[codewars/3kyu] 수 많은 점 중 최소 거리 조합 구하기(Merge Sort, Divide and Conquer를 통한 속도 최적화(병합정렬 알고리즘))

by 소리331 2024. 4. 30.
반응형

함께 풀어보아요~

 

Codewars - Achieve mastery through coding practice and developer mentorship

A coding practice website for all programming levels – Join a community of over 3 million developers and improve your coding skills in over 55 programming languages!

www.codewars.com

가장 가까운 거리를 가지는 두 점을 구해주세요!

이번 문제는 지문은 짧아서 좋습니다. 제목 그대로가 요구사항입니다! 

(
  (2,2), # A
  (2,8), # B
  (5,5), # C
  (6,3), # D
  (6,7), # E
  (7,4), # F
  (7,9)  # G
)
=> closest pair is: ((6,3),(7,4)) or ((7,4),(6,3))
(both answers are valid. You can return a list of tuples too)

문제처럼 point에 대한 목록이 있고, cloest pair처럼 가장 가까운 두 점을 구하는 문제입니다. 다만 여기에서 추가 적으로 고려해야 할 것은 아래와 같습니다. 코드의 실행시간을 선형으로 가져가야 한 다는 것 ! 

your task is to find two points with the smallest distance between them in linearithmic O(n log n) time.

쉽다고 생각했는데 문제는 속도!

from math import pow, sqrt
import pandas as pd
from itertools import combinations


def cal_dist(a, b):
    return sqrt(
        pow(a[0]-b[0], 2)+pow(a[1]-b[1], 2)
    )
    
def closest_pair(points):
    comb = combinations(points, 2)
    df = pd.DataFrame(list(comb), columns=["a", "b"])
    df["dist"] = df.apply(
        lambda x: cal_dist(x.a, x.b)
    )
    return df[
        df["dist"]==df["dist"].min()
    ][["a", "b"]].values[0].tolist()

처음엔 간단하군! 하고 pandas의 apply를 이용해서 진행했었는데, 이건 속도가 느린 이슈가 있었다. 

point의 개수가 많아질 수록 통과하지 못함!!

 

Do You Use Apply in Pandas? There is a 600x Faster Way

By leveraging vectorization and data types, you can massively speed up complex computations in Pandas

towardsdatascience.com

그래서 swifter를 사용하려고 했더니 오류가 발생! 결국에 기초적인 패키지와 알고리즘을 바탕으로 시간 최적화를 얻어내는 것이 이 문제의 포인트가 되겠다. 

Merge Sort(Divide and Conquer )를 통한 코드 실행시간 최적화

사이즈가 커지면 문제가 되는 부분이기 때문에 Divide and Conquer 알고리즘을 통해서 접근해야한다. 그리드를 반으로 쪼갠다음에, 최소 함수가 남을때 까지 재귀적으로 진행하는 것이다. 자세한 내용은 코드에 주석으로 써두었다!  최 하단의 cloesest_pair 부터 읽으면 된다.

from math import sqrt

def cal_dist(a, b):
    return sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)

def closest_pair_rec(points_sorted_x, points_sorted_y):
    # input: x와 y 값을 기준으로 sort 
    num_points = len(points_sorted_x)
    
    if num_points <= 3:
        return min(
            (   # 거리, 점1, 점2를 return 
                cal_dist(points_sorted_x[i], points_sorted_x[j]),
                points_sorted_x[i], points_sorted_x[j]
            ) for i in range(num_points) for j in range(i + 1, num_points)
        )

    # devide 진행
    mid_idx = num_points // 2
    mid_x = points_sorted_x[mid_idx][0]
    Lx = points_sorted_x[:mid_idx]
    Rx = points_sorted_x[mid_idx:]
    midpoint = points_sorted_x[mid_idx][0]
    Ly = list(filter(lambda x: x[0] <= midpoint, points_sorted_y))
    Ry = list(filter(lambda x: x[0] > midpoint, points_sorted_y))
    
    # return은 결국 하나의 점 조합 line 10 참조
    # 아래 코드에서 결과가 나오기 전 까지 함수는 devide 과정을 계속 진행한다. 
    # 이 코드에서는 3개씩 조합을 만들어, min dist 값을 return 한다.
    # L, R 끼리 넣어야, 처음 정렬한 효과를 유지할 수 있다. 
    (d1, p1, q1) = closest_pair_rec(Lx, Ly) 
    (d2, p2, q2) = closest_pair_rec(Rx, Ry)
    
    
    d = min(d1, d2)
    if d == d1:
        pair_min = (p1, q1)
    else:
        pair_min = (p2, q2)

    
    strip = [p for p in points_sorted_y if mid_x - d < p[0] < mid_x + d]
    strip_len = len(strip)
    
    for i in range(strip_len):
        for j in range(i + 1, min(i + 8, strip_len)):
            p, q = strip[i], strip[j]
            dist = cal_dist(p, q)
            if dist < d:
                d = dist
                pair_min = (p, q)
                
    return d, pair_min[0], pair_min[1]

def closest_pair(points):
    # x와 y 값을 기준으로 sort 
    # 거리는 기본적으로 x와 y의 차의 제곱합이기 때문에 각 정렬하여 최대한 가까운 점끼리 계산될 수 있도록 함.
    points_sorted_x = sorted(points, key=lambda x: x[0])
    points_sorted_y = sorted(points, key=lambda x: x[1])
    _, p1, p2 = closest_pair_rec(points_sorted_x, points_sorted_y)
    return [p1, p2]

 

번외: 이미 구현된거 쓰는게 최고!

위에서는 직접 알고리즘을 구현했지만, 사실 이미 구현된 것들을 활용하면 코드를 훨씬 쉽게 쓸 수 있다! 

from scipy.spatial import cKDTree as KDT


def closest_pair(points):
    tree = KDT(points)
    record = None
    nn = tree.query(points, k=2)
    for i, dist in enumerate(nn[0]):
        if not record or record[0] > dist[1]:
            record = [dist[1], points[i], points[nn[1][i][1]]]
    return (record[1], tuple(record[2]))

ㄷㄷ

 

참고자료

반응형

댓글