气有浩然 学无止境

KNN算法的实现与使用

什么是KNN

KNN即k-近邻算法,它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

算法目标


图上数据来自鸢尾花数据集,为了展示方便,只选取数据2个维度的特征,接下来动手实现KNN并对红色点进行分类。

代码实现

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from collections import Counter

def KNN_classify(train_x, train_y, input_x, k):
    assert train_y.shape[0] > k > 0, 'got a wrong k'

    #计算待预测数据和所有样本数据的欧氏距离
    distance_list = [np.sqrt(np.sum((x - input_x) ** 2)) for x in train_x]
    # for x in train_x:
    #     dis = np.sqrt(np.sum((x - inputs) ** 2))
    #     distance_list.append(dis)
    distance_k = np.argsort(distance_list)[:k]
    predict_labels = train_y[distance_k]
    predict = Counter(predict_labels).most_common(1)[0][0]
    return predict

if __name__ == '__main__':

    iris = datasets.load_iris()
    x = iris.data
    y = iris.target
    x = x[:, 1:3]
    inputs = np.array([2.5, 2.5])

    # 画个图直观感受下
    plt.scatter(x[y == 0, 0], x[y == 0, 1], color='blue')
    plt.scatter(x[y == 1, 0], x[y == 1, 1], color='green')
    plt.scatter(inputs[0], inputs[1], color='red')
    plt.show()
    predict = KNN_classify(x, y, inputs, 8)
    print(predict)

⬆️