K-Nearest Neighbor 알고리즘
-레이블(정답)이 없는 예시를 분류하기 위한 알고리즘.
-가장 고전적이고 직관적이라는 특징이 있음.
-새로운 데이터를 입력 받았을 때, 가장 가까이 있는 것이 무엇이냐를 중심으로 새로운 데이터의 종류를 정해주는 알고리즘.
분류 : 라벨이 있음, 지도 학습 / 군집화 : 라벨이 없음, 비지도 학습
NumPy / Matplotlib / Pyplot 라이브러리를 사용해 구현
1. 라이브러리 임포트하기
import numpy as np
import matplotlib.pyplot as plt
2. 데이터 셋 만들기
단맛과 아삭거림을 기준으로 데이터셋을 넣고, 입력한 데이터는 target에 넣음.
grape = [8, 5]
fish = [2, 3]
carrot = [7, 10]
orange = [7, 3]
celery = [3, 8]
cheese = [1, 1]
category = ['과일','단백질','채소','과일','채소','단백질']
dan = int(input('단맛 입력 (1~10) : '))
asac = int(input('아삭거림 입력 (1~10) : '))
target = [dan, asac]
3. 함수 만들기
데이터들을 하나의 넘파이 배열로 합침.
def data_set():
dataset = np.array([grape, fish, carrot, orange, celery, cheese]) # 분류집단
size = len(dataset)
class_target = np.tile(target, (size, 1)) #분류대상
class_category = np.array(category) #분류범주
return dataset, class_target, class_category
#dataset 생성
dataset, class_target, class_category = data_set() #data_set 함수 호출.
4. 유클리드 거리 계산식
def classify(dataset, class_target, class_category, k):
#유클리드 거리계산
diffMat = class_target - dataset # 두점의 차
sqDiffMat = diffMat**2 # 차에 대한 제곱
row_sum = sqDiffMat.sum(axis=1) # 차에 대한 제곱에 대한 합
distance = np.sqrt(row_sum) #차의 대한 제곱에 대한 합의 제곱근 (최종거리)
#가까운 거리 오름차순 정렬
sortDist = distance.argsort() #이웃한 k개 선정
class_result = {} #딕셔너리
for i in range(k):
c = class_category[sortDist[i]] #이웃한 k개 딕셔너리 안에 넣음.
class_result[c] = class_result.get(c, 0) + 1 #딕셔너리의 키에 분류 범주 넣음
return class_result
5. 함수 호출하기
k = int(input('K값 입력(1~3) : '))
class_result = classify(dataset, class_target, class_category, k) # classify() 호출
print(class_result)
6. 출력을 함수화하기
def classify_result(class_result):
protein = fruit = vegetable = 0
for c in class_result.keys():
if c == '단백질':
protein = class_result[c]
elif c =='과일' :
fruit = class_result[c]
else :
vegetable = class_result[c]
if protein > fruit and protein > vegetable:
result = "분류대상은 단백질 입니다."
elif fruit > protein and fruit > vegetable:
result = "분류대상은 과일 입니다"
else :
result = "분류대상은 채소 입니다."
return result
a = classify_result(class_result)
print(a)
7. 출력 결과
8. 시각화
#시각화 (o = 과일 / + = 단백질 / * = 채소)
plt.scatter(8, 5, marker='o')
plt.scatter(2, 3, marker='+')
plt.scatter(7, 10, marker='*')
plt.scatter(7, 3, marker='o')
plt.scatter(3, 8, marker='+')
plt.scatter(1, 1, marker='*')
plt.scatter(dan, asac, color = 'red')
plt.show()
9. 출력 결과
살..려..줘...