K-최근접 이웃(KNN) 알고리즘이란? - 2. 파이썬 밑바닥 코딩
이번 편에서는 유클리드 거리부터 시작해서 최근접 이웃 알고리즘을 밑바닥부터 파이썬으로 구현하도록 하겠다.
간단한 준비 운동으로 2차원의 두 점의 좌표를 이용해 유클리드 거리를 구해보도록 하자. 루트는 math 패키지에서 불러오도록 한다. 또한 두 좌표가 리스트(list)라는 점을 이용해 인덱스 이용해 리스트의 원소들을 불러서 사용하도록 한다.
from math import sqrt
점1 = [1, 3]
점2 = [3, 4]
유클리드_거리 = sqrt((점1[0] - 점2[0])**2 + (점1[1] -점2[1])**2)
print(유클리드_거리)
본격적인 코드에서는 해당 작업을 Numpy 패키지의 행렬 연산 함수들을 이용해 좀 더 단순하게 만들도록 하겠다.
일단 다음 코드 블록에 나열되어 있는 패키지들을 부르고, 딕셔너리 자료구조를 이용해 새로운 데이터를 생성하도록 한다. 해당 데이터에는 k, r 두 가지 클래스가 있고 이 두 가지 클래스로 분류되는 6개의 점들은 두 개의 리스트로 이루어진 리스트로 나누어져 있는 것을 볼 수 있을 것이다. 여기에 새로 관측된 데이터 [5, 7]이 등장했다고 해보자.
import numpy as np
from math import sqrt
import matplotlib.pyplot as plt
from matplotlib import style
from collections import Counter
style.use('fivethirtyeight')
데이터셋 = {'k': [[1,2],[2,3],[3,1]], 'r': [[6,5],[7,6],[8,6]]}
새_데이터 = [5,7]
시각적으로 우리가 무엇을 만들었는지 이해하기 위해서 데이터가 어떻게 분포되어 있는지 직접 보도록 하자. 다음 코드 블록을 이용하면 matplotlib 패키지를 이용해서 데이터 분포를 볼 수 있을 것이다. 독자가 알아볼 수 있도록 줄을 나눠 쓴 코드도 함께 포함했다. 첫 번째 for 반복문은 딕셔너리의 키(key) 리스트를 불러오고 두 번째 for 반복문은 해당 키를 이용해 불러온 리스트의 원소인 리스트를 1 불러온다. 2
다음으로 데이터 한 개의 좌표 정보에 해당하는 리스트를 scatter 함수를 이용해 좌표 평면에 더해준다. scatter 함수의 매개변수(parameter) s는 점 크기를 결정하고 color는 각 점에 해당하는 클래스에 맞춰서 색깔을 입힌다. 마지막으로 좌표 평면을 보고자 한다면 plt.show()를 써준다. 그럼 코드 블록 아래와 같은 그래프가 등장할 것이다.
for i in 데이터셋:
for ii in 데이터셋[i]:
plt.scatter(ii[0], ii[1], s=100 ,color = i)
plt.show()
#Pythonic 표기
[plt.scatter[ii[0], ii[1], s=100 ,color = i) for ii in 데이터셋[i]] for i in 데이터셋]
plt.show()
※여기서 우리는 한가지 꼼수를 사용했는데, 바로 딕셔너리의 키들이 scatter 함수의 color 매개변수의 색깔 k(검은색), r(빨간색)에 해당되는 알파벳을 사용한 것이다. 만약 다른 문자를 사용한다면 해당 문자에 해당되는 색이 없어 바로 오류가 뜰 수 있으므로 주의.
여기에 해당 코드를 plt.show()가 있는 줄 전에 추가해주면 다음과 같이 점이 새로 추가된 그래프를 볼 수 있다.
plt.scatter(새_데이터[0], 새_데이터[1], s=50, c = 'b')
이제 KNN 알고리즘에 해당하는 함수를 쓰도록 하겠다.
def k최근접이웃(데이터, 예상, k=3):
if len(데이터) >= k:
warnings.warn('k 매개변수가 전체 투표 그룹보다 적습니다!')
거리모음 = []
for 그룹 in 데이터:
for 피쳐 in 데이터[그룹]:
#유클리드거리 = np.sqrt(np.sum((np.array(피쳐) - np.array(예상))**2))
유클리드거리 = np.linalg.norm(np.array(피쳐) - np.array(예상))
거리모음.append([유클리드거리, 그룹])
투표 = [i[1] for i in sorted(거리모음)[:k]]
투표결과 = Counter(투표).most_common(1)[0][0]
return 투표결과
k최근접이웃(데이터셋, 새_데이터, 3)
일단 k최근접이웃 함수를 정의하고 매개변수들을 정한다. 매개변수는 순서대로 KNN 알고리즘을 훈련시킬 데이터 딕셔너리, 우리가 분류해야 되는 새로운 데이터 리스트, 그리고 입력이 없으면 기본으로 셋인 투표에 참여하는 가장 가까운 이웃 k 개다.
def k최근접이웃(데이터, 예상, k=3):
기본적으로 우리는 최소한 k가 데이터에 포함되어 있는 클래스 수보다 크길 원할 것이다. 만약 클래스가 3개인데 k가 2라면 우리는 3 클래스 전부에게 투표에 참여할 기회를 박탈시키게 될 것이다. 따라서 유저가 이 사실을 알 수 있도록 경고 메시지를 띄우도록 한다.
if len(데이터) >= k:
warnings.warn('k 매개변수가 전체 투표 그룹보다 적습니다!')
그 후 앞서 데이터를 좌표 평면을 그렸을 때처럼 for 반복문을 이용해 새 데이터와 훈련용 데이터 점들 사이의 유클리드 거리를 '거리모음' 리스트에 저장해준다. 여기서 우리는 전 편에서 언급했던 거리를 저장해야 되기 때문에 훈련용 데이터가 커질수록 메모리가 많이 잡아먹힐 수도 있는 가능성을 볼 수 있다. 다만 이는 단순히 모든 유클리드 거리를 계산하는 것 대신 우리가 원하는 새 데이터에서 비롯되는 기준 구의 원지름 이상의 거리를 가진 점들을 전부 배제하는 것으로 해결할 수도 있다.
유클리드 거리는 앞서 언급했던 첫 번째 코드 블록에서 사용했던 것을 Numpy 패키지를 이용해 행렬 째로 바로 계산할 수 있다. 이에 해당하는 코멘트로 달아놓은 코드에서 예를 들자면 다음과 같은 작업이 일어난다.
-
[1, 2] - [5, 7] = [-4, -5]
-
[-4, -5] → [16, 25]
-
16+25 = 41
-
41 → √(41)
그러나 Numpy 패키지에는 이것보다 계산속도가 빠른 함수가 존재하는데, 그게 바로 노름(norm)을 계산해주는 함수다. 우리는 [1, 2] - [5, 7]를 한 개의 벡터로 취급해 그 벡터의 노름 혹은 벡터 거리를 계산하여 유클리드 거리를 도출해낼 수 있다.
마지막으로 유클리드 거리를 '거리모음' 리스트에 넣어준다.
거리모음 = []
for 그룹 in 데이터:
for 피쳐 in 데이터[그룹]:
#유클리드거리 = np.sqrt(np.sum((np.array(피쳐) - np.array(예상))**2))
유클리드거리 = np.linalg.norm(np.array(피쳐) - np.array(예상))
거리모음.append([유클리드거리, 그룹])
이제 가장 가까운 k 개의 점들이 투표하는 과정을 정의해준다.
투표 = [i[1] for i in sorted(거리모음)[:k]]
투표결과 = Counter(투표).most_common(1)[0][0]
return 투표결과
일단 sorted 함수를 이용해 거리가 작은 순으로 '거리모음' 리스트를 정렬해준다. sorted는 '거리모음'의 원소인 리스트의 첫 번째 원소인 유클리드 거리를 이용해 '거리모음' 리스트를 정렬해 줄 것이다. 그리고 정렬된 '거리모음' 리스트에서 인덱스 슬라이싱을 통해 k개의 가장 가까운 이웃들의 투표를 추출해낸다. 이제 투표 리스트는 가장 가까운 k개의 점들의 '그룹'값을 가지고 있을 것이다.
다음으로 Counter 함수를 이용해 투표를 세준다. Counter 함수의 most_common(n) 속성은 n가지 제일 많은 원소들에 대한 정보가 담긴 투플로 이루어진 리스트를 주는데 우리의 경우에는 [('r', 3)]을 줄 것이다. 이는 앞서 우리가 만들었던 좌표평면 그래프에서 세 개의 붉은 점들이 가장 가깝다는 점에서 예상할 수 있는 결과다. 우리가 원하는 것은 'r' 값이고 이는 해당 리스트의 첫번째 투플의 첫번째 원소이므로 [0][0] 인덱스로 추출해낼 수 있다.
이제 우리는 밑바닥부터 KNN 알고리즘을 만들어냈고, 해당 알고리즘이 어떻게 작동하는지 배웠다. 다음 편에서는 실제 데이터에 적용해보고 scikit learn 내장 KNN 알고리즘과 성능 비교를 해보도록 하겠다.
[Copyright ⓒ 블로그채널 무단전재 및 재배포 금지]