手撸 K近邻 算法
可用的库为 NumPy
、Matplotlib
,不能使用 sklearn
库,然后将该算法用于鸢尾花问题分类。
要求使用三种距离函数:欧氏距离
、 曼哈顿距离
与 夹角余弦
。
要求对 K 设置三种不同数值,对比分析三种不同距离的不同K值何种效果最好。
K近邻算法视频介绍:https://www.bilibili.com/video/BV1vA4y1o72F/?p=1
摘要
KNN算法是一种基于实例的监督学习算法,用于解决分类和回归问题。在KNN算法中,给定一个未知样本,首先找到与该样本最相似的K个训练样本,然后根据这K个训练样本的类别信息,进行分类或回归预测。
KNN算法的核心思想是“近朱者赤,近墨者黑”。也就是说,与未知样本相似的训练样本,往往具有相同的类别信息。因此,在KNN算法中,找到与未知样本最相似的K个训练样本,可以有效地预测该未知样本的类别或数值。
概述
KNN算法的具体实现过程包括以下几个步骤:
- 选择距离度 量方法,如欧氏距离、曼哈顿距离等。
- 给定一个未知样本,计算该样本与训练集中每个样本之间的距离。
- 根据距离从小到大对训练集中的样本进行排序,选取与未知样本最近的K个样本。
- 根据K个样本的类别信息,进行分类或回归预测。对于分类问题,可以采用多数投票等方法决定样本的类别;对于回归问题,可以采用平均值等方法预测样本的数值。
在本实验中,应该没有涉及到回归问题。
我的工作
为了尽量减少对第三方库的依赖,在本实验中,我仅使用 Matplotlib
库来绘制图表。
需要导入的库如下:
from dataclasses import dataclass
from math import sqrt
import csv
from typing import Callable
import matplotlib.pyplot as plt
下载鸢尾花数据集
数据集下载地址:https://archive.ics.uci.edu/dataset/53/iris
下载后将压缩包中的 iris.data
文件放到工作目录下。
该数据集文件中包含150个样本,每个样本包含4个特征和1个 类别信息。其中,4个特征分别为花萼长度、花萼宽度、花瓣长度和花瓣宽度,类别信息包含3种鸢尾花的类别,分别为山鸢尾、变色鸢尾和维吉尼亚鸢尾。
读取数据集
该数据集为标准的 csv 格式文件,使用 python 内置的 csv
库解析即可。
def load_iris(data_file_path: str) -> list[Iris]:
"""加载鸢尾花数据集"""
with open(data_file_path, newline='') as f:
iris_list: list[Iris] = []
for row in csv.reader(f):
if len(row) != 5:
continue
iris_list.append(Iris(float(row[0]), float(row[1]), float(row[2]), float(row[3]), row[4]))
return iris_list
在读取数据集时,我将每个样本的特征和类别信息封装到了 Iris
类中,方便后续的处理。
@dataclass
class Iris:
"""鸢尾花"""
sepal_length: float # 花萼长度
sepal_width: float # 花萼宽度
petal_length: float # 花瓣长度
petal_width: float # 花瓣宽度
clazz: str # 类别
定义距离算法
在KNN算法中,距离度量方法是一个重要的参数,它决定了样本之间的相似度。常用的距离度量方法包括欧氏距离、曼哈顿距离、夹角余弦等。
欧氏距离
def euclidean_distance(x1: list[float], x2: list[float]) -> float:
"""欧氏距离"""
distance = []
for i in range(len(x1)):
distance.append((x1[i] - x2[i]) ** 2)
return sqrt(sum(distance))
曼哈顿距离
def manhattan_distance(x1: list[float], x2: list[float]) -> float:
"""曼哈顿距离"""
distance = []
for i in range(len(x1)):
distance.append(abs(x1[i] - x2[i]))
return sum(distance)
夹角余弦
def cosine_distance(x1: list[float], x2: list[float]) -> float:
"""余弦距离"""
dot_product = sum([x1[i] * x2[i] for i in range(len(x1))])
mod_x1 = sqrt(sum([x ** 2 for x in x1]))
mod_x2 = sqrt(sum([x ** 2 for x in x2]))
return 1 - dot_product / (mod_x1 * mod_x2)
封装距离算法
为了使得这些距离算法适配该实验的数据集,我将距离算法函数进行进一步的封装,使得函数的参数为两个 Iris
对象,返回值为两个样本之间的距离。
def distance_for_iris(ori_func: Callable[[list[float], list[float]], float]) -> Callable[[Iris, Iris], float]:
"""针对鸢尾花的距离函数封装"""
def distance(iris1: Iris, iris2: Iris) -> float:
return ori_func(
[iris1.sepal_length, iris1.sepal_width, iris1.petal_length, iris1.petal_width],
[iris2.sepal_length, iris2.sepal_width, iris2.petal_length, iris2.petal_width]
)
return distance