加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
KNN.py 935 Bytes
一键复制 编辑 原始数据 按行查看 历史
叶新尔 提交于 2021-12-01 13:35 . 图像分类
import numpy as np
#knn分类函数,输入训练样本,训练样本标签,测试样本(为ndarray),k值(k默认为3);返回预测标签
def KNN(traindata,trainlabels,testdata,k=3):
#计算距离,存放在dist中,每一行表示测试样本与所有训练样本的距离
num_train=traindata.shape[0]
num_test=testdata.shape[0]
dist=np.zeros((num_test,num_train))
for i in range(num_test):
dist[i]=np.reshape(np.sqrt(np.sum(np.square(testdata[i]-traindata),axis=1)),[1,num_train])
#找到每一行中从小到大排序后前k个值所对应的原来的索引,对应的标签给close_k
#统计这些索引中出现次数最多的那个数为预测样本类别
predictlabels=np.zeros((num_test,1))
for i in range(num_test):
close_k=trainlabels[np.argsort(dist[i])[:k]]
predictlabels[i]=np.argmax(np.bincount(close_k))
return predictlabels
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化