`
zhoupinheng
  • 浏览: 34209 次
  • 性别: Icon_minigender_1
  • 来自: 深圳
社区版块
存档分类
最新评论

java实现K近邻算法

阅读更多
package knn;

import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class KNN {

	public String getKnnValue(double[][] value, String[] tags, int k, double[] verifyValue) {

		String tag = null;
		if (value == null || tags == null || verifyValue == null || k < 1) {
			// para error
		} else {
			if (value.length == tags.length && value[0].length == verifyValue.length) {
				PriorityQueue<Item> queue = new PriorityQueue<Item>(k + 1);
				for (int i = 0; i < value.length; i++) {
					queue.add(new Item(getOuDistance(value[i], verifyValue), tags[i]));
					if (queue.size() > k) {
						queue.poll();
					}
				}

				Map<String, Item> map = new HashMap<String, Item>();
				for (Item item : queue) {
					Item countItem = null;
					if (!map.containsKey(item.strValue)) {
						map.put(item.strValue, new Item(0, item.strValue));
					}
					countItem = map.get(item.strValue);
					countItem.numValue = countItem.numValue + 1.0;
				}

				queue.clear();
				queue.addAll(map.values());

				tag = queue.poll().strValue;
			}
		}

		return tag;
	}

	private double getOuDistance(double[] fs, double[] validateValue) {
		double value = 0;
		for (int i = 0; i < fs.length; i++) {
			value += Math.pow((fs[i] - validateValue[i]), 2.0);
		}
		return Math.sqrt(value);
	}

	private class Item implements Comparable<Item> {
		public Item(double num, String str) {
			this.numValue = num;
			this.strValue = str;
		}

		public double numValue = 0;
		public String strValue;

		@Override
		public int compareTo(Item o) {
			if (numValue > o.numValue) {
				return -1;
			} else if (numValue < o.numValue) {
				return 1;
			} else {
				return 0;
			}
		}
	}

}

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics