`

K最近邻(KNN)算法原理和java实现

阅读更多

原理部分:

请参考:KNN演算法

 

 

代码实现:

 

KNN结点类,用来存储最近邻的k个元组相关的信息

/**
 * KNN结点类,用来存储最近邻的k个元组相关的信息
 */
public class KNNNode {
	private int index; 			// 元组标号
	private double distance; 	// 与测试元组的距离
	private String c; 			// 所属类别
	public KNNNode(int index, double distance, String c) {
		super();
		this.index = index;
		this.distance = distance;
		this.c = c;
	}
	
	
	public int getIndex() {
		return index;
	}
	public void setIndex(int index) {
		this.index = index;
	}
	public double getDistance() {
		return distance;
	}
	public void setDistance(double distance) {
		this.distance = distance;
	}
	public String getC() {
		return c;
	}
	public void setC(String c) {
		this.c = c;
	}
}

 

 

KNN算法主体类

/**
 * KNN算法主体类
 */
public class KNN {
	/**
	 * 设置优先级队列的比较函数,距离越大,优先级越高
	 */
	private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
		public int compare(KNNNode o1, KNNNode o2) {
			if (o1.getDistance() >= o2.getDistance()) {
				return 1;
			} else {
				return 0;
			}
		}
	};
	/**
	 * 获取K个不同的随机数
	 * @param k 随机数的个数
	 * @param max 随机数最大的范围
	 * @return 生成的随机数数组
	 */
	public List<Integer> getRandKNum(int k, int max) {
		List<Integer> rand = new ArrayList<Integer>(k);
		for (int i = 0; i < k; i++) {
			int temp = (int) (Math.random() * max);
			if (!rand.contains(temp)) {
				rand.add(temp);
			} else {
				i--;
			}
		}
		return rand;
	}
	/**
	 * 计算测试元组与训练元组之前的距离
	 * @param d1 测试元组
	 * @param d2 训练元组
	 * @return 距离值
	 */
	public double calDistance(List<Double> d1, List<Double> d2) {
		double distance = 0.00;
		for (int i = 0; i < d1.size(); i++) {
			distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
		}
		return distance;
	}
	/**
	 * 执行KNN算法,获取测试元组的类别
	 * @param datas 训练数据集
	 * @param testData 测试元组
	 * @param k 设定的K值
	 * @return 测试元组的类别
	 */
	public String knn(List<List<Double>> datas, List<Double> testData, int k) {
		PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);
		List<Integer> randNum = getRandKNum(k, datas.size());
		for (int i = 0; i < k; i++) {
			int index = randNum.get(i);
			List<Double> currData = datas.get(index);
			String c = currData.get(currData.size() - 1).toString();
			KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
			pq.add(node);
		}
		for (int i = 0; i < datas.size(); i++) {
			List<Double> t = datas.get(i);
			double distance = calDistance(testData, t);
			KNNNode top = pq.peek();
			if (top.getDistance() > distance) {
				pq.remove();
				pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
			}
		}
		
		return getMostClass(pq);
	}
	/**
	 * 获取所得到的k个最近邻元组的多数类
	 * @param pq 存储k个最近近邻元组的优先级队列
	 * @return 多数类的名称
	 */
	private String getMostClass(PriorityQueue<KNNNode> pq) {
		Map<String, Integer> classCount = new HashMap<String, Integer>();
		for (int i = 0; i < pq.size(); i++) {
			KNNNode node = pq.remove();
			String c = node.getC();
			if (classCount.containsKey(c)) {
				classCount.put(c, classCount.get(c) + 1);
			} else {
				classCount.put(c, 1);
			}
		}
		int maxIndex = -1;
		int maxCount = 0;
		Object[] classes = classCount.keySet().toArray();
		for (int i = 0; i < classes.length; i++) {
			if (classCount.get(classes[i]) > maxCount) {
				maxIndex = i;
				maxCount = classCount.get(classes[i]);
			}
		}
		return classes[maxIndex].toString();
	}
}

 

KNN算法测试类

/**
 * KNN算法测试类
 */
public class TestKNN {
	
	/**
	 * 从数据文件中读取数据
	 * @param datas 存储数据的集合对象
	 * @param path 数据文件的路径
	 */
	public void read(List<List<Double>> datas, String path){
		try {
			BufferedReader br = new BufferedReader(new FileReader(new File(path)));
			String data = br.readLine();
			List<Double> l = null;
			while (data != null) {
				String t[] = data.split(" ");
				l = new ArrayList<Double>();
				for (int i = 0; i < t.length; i++) {
					l.add(Double.parseDouble(t[i]));
				}
				datas.add(l);
				data = br.readLine();
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
	
	/**
	 * 程序执行入口
	 * @param args
	 */
	public static void main(String[] args) {
		TestKNN t = new TestKNN();
		String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
		String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
		try {
			List<List<Double>> datas = new ArrayList<List<Double>>();
			List<List<Double>> testDatas = new ArrayList<List<Double>>();
			t.read(datas, datafile);
			t.read(testDatas, testfile);
			KNN knn = new KNN();
			for (int i = 0; i < testDatas.size(); i++) {
				List<Double> test = testDatas.get(i);
				System.out.print("测试元组: ");
				for (int j = 0; j < test.size(); j++) {
					System.out.print(test.get(j) + " ");
				}
				System.out.print("类别为: ");
				System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
	}
}

 

  • KNN.rar (9.4 KB)
  • 下载次数: 117
分享到:
评论

相关推荐

    KNN算法原理和java实现.doc

    KNN算法原理和java实现,K最近邻分类

    KNN最近邻算法java实现

    KNN最近邻算法java实现

    Java实现kNN算法

    Java实现kNN算法

    knn.rar_KNN java_KNN算法 java_knn算法_knn算法java实现

    KNN算法的实现过程,用java实现,很详细的哦

    KNN算法实验报告【Java实现】.doc

    KNN算法实验报告【Java实现】.doc

    KNN算法java实现

    该KNN算法示例采用java实现,对数据分类算法的学习很有用,而且代码封装很好,简单易懂,极适合初学者

    python可视化实现KNN算法

    KNN–最近邻分类算法,算法逻辑比较简单,思路如下: 1.设一待分类数据iData,先计算其到已标记数据集中每个数据的距离,例如欧拉距离sqrt((x1-x2)^2+(y1-y2)^2); 2.然后根据离iData最近的k个数据的分类,出现次数...

    KNN算法的Java实现

    邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。 kNN算法的核心思想是...

    KNN算法JVAA实现

    KNN算法JVAA实现邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法可以说是整个数据挖掘分类技术中最简单的方法了。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用她最接近的k个邻居来代表...

    机器学习实战 - k近邻算法(KNN算法)总结

    K-近邻算法,又称为 KNN 算法,是数据挖掘技术中原理最简单的算法。 KNN 的工作原理:给定一个已知类别标签的数据训练集,输入没有标签的新数据后,在训练数据集中找到与新数据最临近的 K 个实例。如果这 K 个实例的...

    KNN算法的matlab实现

    模式识别中的KNN算法实现,基于Matlab的实现,以及剪辑近邻法的matlab实现。

    mnist手写体的识别采用KNN算法,Java实现

    mnist手写体的识别采用KNN算法,Java实现,60K训练集,10K测试集。代码主要包括读mnist数据集,KNN算法。

    论文研究-K最近邻算法理论与应用综述.pdf

    k最近邻算法(kNN)是一个十分简单的分类算法,该算法包括两个步骤:(1)在给定的搜索训练集上按一定距离度量,寻找一个k的值。(2)在这个kNN算法当中,根据大多数分为一致的类来进行分类。kNN算法具有的非参数...

    KNN算法实现分类问题(JAVA)实现

    这个是我实验课的作业,Java实现knn算法,对网上需手动输入数据的算法进行了一些改进,注释详细,数据是文件夹中的txt文件,读者可以自己更换成自己的数据。

    KNN算法的MATLAB实现

    邻近算法,或者说K最近邻(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据...

    机器学习算法KNN(K近邻)应用实例——实现对是否患糖尿病的预测

    资源中包括完成的KNN算法训练和实现过程,以及用于机器学习的糖尿病数据集。 数据特征包括: Pregnancies:怀孕次数 Glucose:葡萄糖测试值 BloodPressure:血压 SkinThickness:表皮厚度 Insulin:胰岛素 BMI:身体...

    KNN.rar_java k-nn_knn算法_weka k_weka knn

    KNN算法用JAVA实现,在WEKA平台上实现

    贝叶斯与KNN算法实现

    2. 设计基于最近邻准则的分类器。 资源包括代码实现和课程报告--Bayes和KNN分类器实现鸢尾花数据集分类 源码实现包括手撕贝叶斯和KNN以及使用工具包实现 课程报告主要包括以下部分: 一、 问题描述 二、 数据预处理 ...

    java实现的KNN算法

    java实现的KNN算法,对学习KNN算法有帮助

    KNN.rar_K._cutxs1_featherszxg_k-最近邻_利用numpy库实现KNN算法

    KNN算法,K最近邻分类算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

Global site tag (gtag.js) - Google Analytics