`

生成文本聚类java实现 (2)

 
阅读更多
Java代码  收藏代码
  1. 4.从剩余的词中提取文本特征,即最能代表文本的词    
  2. 5.用空间向量表示文本,空间向量需标准化,即将数值映射到-11之间    
  3. 6.利用所获取的空间向量进行聚类分析    
  4. 7.交叉验证    

第四步,提取文本特征

 本文使用KNN算法和SVM算法学习提取文本特征的思想。

 研究最终目的。
 训练材料:

语料 分类
腐化  "生活作风"  "女色"  "情妇"   "权色"   "生活糜烂"   "生活堕落" 生活作风

"东城" "西城" "崇文" "宣武" "朝阳" "海淀" "丰台" "石景山" "房山" "通州" "顺义" "大兴" "昌平

" "平谷" "怀柔" "门头沟" "密云" "延庆"

北京
上访  信访  举报  揭发  揭露  "买官"  "卖官" 上访举报
李刚  "河大撞人" 撞人 我爸是李刚
"送钱短信" OR ( 驾校 AND 交警 ) 送钱短信

"乐东县" "保亭县" "陵水县" "琼中县" "白沙县" "昌江县" "屯昌县" "定安县" "澄迈县" "临高县"

"儋州" "东方" "五指山" "万宁" "琼海" "文昌" "三亚" "海口"

海南

 

 训练结果就是跟上面语料和分类的有极高的相似度。

 

下面是基本的KNN算法。KNN.java

Java代码  收藏代码
  1. package com.antbee.cluster.knn;  
  2.   
  3. import java.util.ArrayList;  
  4. import java.util.Comparator;  
  5. import java.util.HashMap;  
  6. import java.util.List;  
  7. import java.util.Map;  
  8. import java.util.PriorityQueue;  
  9.   
  10. /** 
  11.  * @author KNN算法主体类 
  12.  * @version 创建时间:2011-4-2 下午03:47:28 
  13.  * 类说明 
  14.  */  
  15. public class KNN {  
  16.     /**  
  17.      * 设置优先级队列的比较函数,距离越大,优先级越高  
  18.      */    
  19.     private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {    
  20.         public int compare(KNNNode o1, KNNNode o2) {    
  21.             if (o1.getDistance() >= o2.getDistance()) {    
  22.                 return 1;    
  23.             } else {    
  24.                 return 0;    
  25.             }    
  26.         }    
  27.     };    
  28.     /**  
  29.      * 获取K个不同的随机数  
  30.      * @param k 随机数的个数  
  31.      * @param max 随机数最大的范围  
  32.      * @return 生成的随机数数组  
  33.      */    
  34.     public List<Integer> getRandKNum(int k, int max) {    
  35.         List<Integer> rand = new ArrayList<Integer>(k);    
  36.         for (int i = 0; i < k; i++) {    
  37.             int temp = (int) (Math.random() * max);    
  38.             if (!rand.contains(temp)) {    
  39.                 rand.add(temp);    
  40.             } else {    
  41.                 i--;    
  42.             }    
  43.         }    
  44.         return rand;    
  45.     }    
  46.     /**  
  47.      * 计算测试元组与训练元组之前的距离  
  48.      * @param d1 测试元组  
  49.      * @param d2 训练元组  
  50.      * @return 距离值  
  51.      */    
  52.     public double calDistance(List<Double> d1, List<Double> d2) {    
  53.         double distance = 0.00;    
  54.         for (int i = 0; i < d1.size(); i++) {    
  55.             distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));    
  56.         }    
  57.         return distance;    
  58.     }    
  59.     /**  
  60.      * 执行KNN算法,获取测试元组的类别  
  61.      * @param datas 训练数据集  
  62.      * @param testData 测试元组  
  63.      * @param k 设定的K值  
  64.      * @return 测试元组的类别  
  65.      */    
  66.     public String knn(List<List<Double>> datas, List<Double> testData, int k) {    
  67.         PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);    
  68.         List<Integer> randNum = getRandKNum(k, datas.size());    
  69.         for (int i = 0; i < k; i++) {    
  70.             int index = randNum.get(i);    
  71.             List<Double> currData = datas.get(index);    
  72.             String c = currData.get(currData.size() - 1).toString();    
  73.             KNNNode node = new KNNNode(index, calDistance(testData, currData), c);    
  74.             pq.add(node);    
  75.         }    
  76.         for (int i = 0; i < datas.size(); i++) {    
  77.             List<Double> t = datas.get(i);    
  78.             double distance = calDistance(testData, t);    
  79.             KNNNode top = pq.peek();    
  80.             if (top.getDistance() > distance) {    
  81.                 pq.remove();    
  82.                 pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));    
  83.             }    
  84.         }    
  85.             
  86.         return getMostClass(pq);    
  87.     }    
  88.     /**  
  89.      * 获取所得到的k个最近邻元组的多数类  
  90.      * @param pq 存储k个最近近邻元组的优先级队列  
  91.      * @return 多数类的名称  
  92.      */    
  93.     private String getMostClass(PriorityQueue<KNNNode> pq) {    
  94.         Map<String, Integer> classCount = new HashMap<String, Integer>();    
  95.         for (int i = 0; i < pq.size(); i++) {    
  96.             KNNNode node = pq.remove();    
  97.             String c = node.getC();    
  98.             if (classCount.containsKey(c)) {    
  99.                 classCount.put(c, classCount.get(c) + 1);    
  100.             } else {    
  101.                 classCount.put(c, 1);    
  102.             }    
  103.         }    
  104.         int maxIndex = -1;    
  105.         int maxCount = 0;    
  106.         Object[] classes = classCount.keySet().toArray();    
  107.         for (int i = 0; i < classes.length; i++) {    
  108.             if (classCount.get(classes[i]) > maxCount) {    
  109.                 maxIndex = i;    
  110.                 maxCount = classCount.get(classes[i]);    
  111.             }    
  112.         }    
  113.         return classes[maxIndex].toString();    
  114.     }    
  115. }  

   KNNNode.java 结点类

Java代码  收藏代码
  1. package com.antbee.cluster.knn;  
  2. /** 
  3.  * @author KNN结点类,用来存储最近邻的k个元组相关的信息  
  4.  * @version 创建时间:2011-4-2 下午03:43:39 
  5.  * 类说明 
  6.  */  
  7. public class KNNNode {  
  8.     private int index; // 元组标号    
  9.     private double distance; // 与测试元组的距离    
  10.     private String c; // 所属类别    
  11.     public KNNNode(int index, double distance, String c) {    
  12.         super();    
  13.         this.index = index;    
  14.         this.distance = distance;    
  15.         this.c = c;    
  16.     }    
  17.         
  18.         
  19.     public int getIndex() {    
  20.         return index;    
  21.     }    
  22.     public void setIndex(int index) {    
  23.         this.index = index;    
  24.     }    
  25.     public double getDistance() {    
  26.         return distance;    
  27.     }    
  28.     public void setDistance(double distance) {    
  29.         this.distance = distance;    
  30.     }    
  31.     public String getC() {    
  32.         return c;    
  33.     }    
  34.     public void setC(String c) {    
  35.         this.c = c;    
  36.     }    
  37. }  

 TestKNN.java 测试类

Java代码  收藏代码
  1. package com.antbee.cluster.knn;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.util.ArrayList;  
  7. import java.util.List;  
  8.   
  9. import org.junit.Test;  
  10.   
  11. /** 
  12.  * @author Weiya He E-mail:heweiya@gmail.com 
  13.  * @version 创建时间:2011-4-2 下午03:49:04 
  14.  * 类说明 
  15.  */  
  16. public class TestKNN {  
  17.     /**  
  18.      * 从数据文件中读取数据  
  19.      * @param datas 存储数据的集合对象  
  20.      * @param path 数据文件的路径  
  21.      */    
  22.     public void read(List<List<Double>> datas, String path){    
  23.         try {    
  24.             BufferedReader br = new BufferedReader(new FileReader(new File(path)));    
  25.             String data = br.readLine();    
  26.             List<Double> l = null;    
  27.             while (data != null) {    
  28.                 String t[] = data.split(" ");    
  29.                 l = new ArrayList<Double>();    
  30.                 for (int i = 0; i < t.length; i++) {    
  31.                     l.add(Double.parseDouble(t[i]));    
  32.                 }    
  33.                 datas.add(l);    
  34.                 data = br.readLine();    
  35.             }    
  36.         } catch (Exception e) {    
  37.             e.printStackTrace();    
  38.         }    
  39.     }    
  40.         
  41.     /**  
  42.      * 程序执行入口  
  43.      * @param args  
  44.      */    
  45.     @Test  
  46.     public void test() {    
  47.         TestKNN t = new TestKNN();    
  48.         String datafile = this.getClass().getClassLoader().getResource("datafile.txt").toString();  
  49.         datafile = datafile.replace("file:/""");//windows 环境上要做的一步  
  50.         String testfile = this.getClass().getClassLoader().getResource("testfile.txt").toString();  
  51.         testfile = testfile.replace("file:/""");//windows 环境上要做的一步  
  52.         try {    
  53.             List<List<Double>> datas = new ArrayList<List<Double>>();    
  54.             List<List<Double>> testDatas = new ArrayList<List<Double>>();    
  55.             t.read(datas, datafile);    
  56.             t.read(testDatas, testfile);    
  57.             KNN knn = new KNN();    
  58.             for (int i = 0; i < testDatas.size(); i++) {    
  59.                 List<Double> test = testDatas.get(i);    
  60.                 System.out.print("测试元组: ");    
  61.                 for (int j = 0; j < test.size(); j++) {    
  62.                     System.out.print(test.get(j) + " ");    
  63.                 }    
  64.                 System.out.print("类别为: ");    
  65.                 System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 2)))));    
  66.             }    
  67.         } catch (Exception e) {    
  68.             e.printStackTrace();    
  69.         }    
  70.     }    
  71. }  

  datafile.txt文件内容:

Java代码  收藏代码
  1. 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1    
  2. 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1    
  3. 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1    
  4. 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0    
  5. 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1    
  6. 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0  

  testfile.txt文件内容:

Java代码  收藏代码
  1. 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5    
  2. 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8    
  3. 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2    
  4. 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5    
  5. 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5    
  6. 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5  

 最终的运行结果:

Java代码  收藏代码
  1. 测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1  
  2. 测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1  
  3. 测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1  
  4. 测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0  
  5. 测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1  
  6. 测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0  

 下面的工作就是如何让汉字也成为如上的Long类型的数字呢,我们现在使用词频的空间向量来代替这些文字。

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics