论坛首页 综合技术论坛

生成文本聚类java实现 (2)

浏览 2939 次
精华帖 (0) :: 良好帖 (0) :: 新手帖 (0) :: 隐藏帖 (0)
作者 正文
   发表时间:2011-04-12   最后修改:2011-04-12

   呵呵,继续。

 本节的学习内容:

 

4.从剩余的词中提取文本特征,即最能代表文本的词  
5.用空间向量表示文本,空间向量需标准化,即将数值映射到-1到1之间  
6.利用所获取的空间向量进行聚类分析  
7.交叉验证  

第四步,提取文本特征

 本文使用KNN算法和SVM算法学习提取文本特征的思想。

 研究最终目的。
 训练材料:

语料 分类
腐化  "生活作风"  "女色"  "情妇"   "权色"   "生活糜烂"   "生活堕落" 生活作风

"东城" "西城" "崇文" "宣武" "朝阳" "海淀" "丰台" "石景山" "房山" "通州" "顺义" "大兴" "昌平

" "平谷" "怀柔" "门头沟" "密云" "延庆"

北京
上访  信访  举报  揭发  揭露  "买官"  "卖官" 上访举报
李刚  "河大撞人" 撞人 我爸是李刚
"送钱短信" OR ( 驾校 AND 交警 ) 送钱短信

"乐东县" "保亭县" "陵水县" "琼中县" "白沙县" "昌江县" "屯昌县" "定安县" "澄迈县" "临高县"

"儋州" "东方" "五指山" "万宁" "琼海" "文昌" "三亚" "海口"

海南

 

 训练结果就是跟上面语料和分类的有极高的相似度。

 

下面是基本的KNN算法。KNN.java

package com.antbee.cluster.knn;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * @author KNN算法主体类
 * @version 创建时间:2011-4-2 下午03:47:28
 * 类说明
 */
public class KNN {
	/** 
     * 设置优先级队列的比较函数,距离越大,优先级越高 
     */  
    private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {  
        public int compare(KNNNode o1, KNNNode o2) {  
            if (o1.getDistance() >= o2.getDistance()) {  
                return 1;  
            } else {  
                return 0;  
            }  
        }  
    };  
    /** 
     * 获取K个不同的随机数 
     * @param k 随机数的个数 
     * @param max 随机数最大的范围 
     * @return 生成的随机数数组 
     */  
    public List<Integer> getRandKNum(int k, int max) {  
        List<Integer> rand = new ArrayList<Integer>(k);  
        for (int i = 0; i < k; i++) {  
            int temp = (int) (Math.random() * max);  
            if (!rand.contains(temp)) {  
                rand.add(temp);  
            } else {  
                i--;  
            }  
        }  
        return rand;  
    }  
    /** 
     * 计算测试元组与训练元组之前的距离 
     * @param d1 测试元组 
     * @param d2 训练元组 
     * @return 距离值 
     */  
    public double calDistance(List<Double> d1, List<Double> d2) {  
        double distance = 0.00;  
        for (int i = 0; i < d1.size(); i++) {  
            distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));  
        }  
        return distance;  
    }  
    /** 
     * 执行KNN算法,获取测试元组的类别 
     * @param datas 训练数据集 
     * @param testData 测试元组 
     * @param k 设定的K值 
     * @return 测试元组的类别 
     */  
    public String knn(List<List<Double>> datas, List<Double> testData, int k) {  
        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);  
        List<Integer> randNum = getRandKNum(k, datas.size());  
        for (int i = 0; i < k; i++) {  
            int index = randNum.get(i);  
            List<Double> currData = datas.get(index);  
            String c = currData.get(currData.size() - 1).toString();  
            KNNNode node = new KNNNode(index, calDistance(testData, currData), c);  
            pq.add(node);  
        }  
        for (int i = 0; i < datas.size(); i++) {  
            List<Double> t = datas.get(i);  
            double distance = calDistance(testData, t);  
            KNNNode top = pq.peek();  
            if (top.getDistance() > distance) {  
                pq.remove();  
                pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));  
            }  
        }  
          
        return getMostClass(pq);  
    }  
    /** 
     * 获取所得到的k个最近邻元组的多数类 
     * @param pq 存储k个最近近邻元组的优先级队列 
     * @return 多数类的名称 
     */  
    private String getMostClass(PriorityQueue<KNNNode> pq) {  
        Map<String, Integer> classCount = new HashMap<String, Integer>();  
        for (int i = 0; i < pq.size(); i++) {  
            KNNNode node = pq.remove();  
            String c = node.getC();  
            if (classCount.containsKey(c)) {  
                classCount.put(c, classCount.get(c) + 1);  
            } else {  
                classCount.put(c, 1);  
            }  
        }  
        int maxIndex = -1;  
        int maxCount = 0;  
        Object[] classes = classCount.keySet().toArray();  
        for (int i = 0; i < classes.length; i++) {  
            if (classCount.get(classes[i]) > maxCount) {  
                maxIndex = i;  
                maxCount = classCount.get(classes[i]);  
            }  
        }  
        return classes[maxIndex].toString();  
    }  
}

   KNNNode.java 结点类

package com.antbee.cluster.knn;
/**
 * @author KNN结点类,用来存储最近邻的k个元组相关的信息 
 * @version 创建时间:2011-4-2 下午03:43:39
 * 类说明
 */
public class KNNNode {
	private int index; // 元组标号  
    private double distance; // 与测试元组的距离  
    private String c; // 所属类别  
    public KNNNode(int index, double distance, String c) {  
        super();  
        this.index = index;  
        this.distance = distance;  
        this.c = c;  
    }  
      
      
    public int getIndex() {  
        return index;  
    }  
    public void setIndex(int index) {  
        this.index = index;  
    }  
    public double getDistance() {  
        return distance;  
    }  
    public void setDistance(double distance) {  
        this.distance = distance;  
    }  
    public String getC() {  
        return c;  
    }  
    public void setC(String c) {  
        this.c = c;  
    }  
}

 TestKNN.java 测试类

package com.antbee.cluster.knn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

import org.junit.Test;

/**
 * @author Weiya He E-mail:heweiya@gmail.com
 * @version 创建时间:2011-4-2 下午03:49:04
 * 类说明
 */
public class TestKNN {
	/** 
     * 从数据文件中读取数据 
     * @param datas 存储数据的集合对象 
     * @param path 数据文件的路径 
     */  
    public void read(List<List<Double>> datas, String path){  
        try {  
            BufferedReader br = new BufferedReader(new FileReader(new File(path)));  
            String data = br.readLine();  
            List<Double> l = null;  
            while (data != null) {  
                String t[] = data.split(" ");  
                l = new ArrayList<Double>();  
                for (int i = 0; i < t.length; i++) {  
                    l.add(Double.parseDouble(t[i]));  
                }  
                datas.add(l);  
                data = br.readLine();  
            }  
        } catch (Exception e) {  
            e.printStackTrace();  
        }  
    }  
      
    /** 
     * 程序执行入口 
     * @param args 
     */  
    @Test
    public void test() {  
        TestKNN t = new TestKNN();  
        String datafile = this.getClass().getClassLoader().getResource("datafile.txt").toString();
        datafile = datafile.replace("file:/", "");//windows 环境上要做的一步
        String testfile = this.getClass().getClassLoader().getResource("testfile.txt").toString();
        testfile = testfile.replace("file:/", "");//windows 环境上要做的一步
        try {  
            List<List<Double>> datas = new ArrayList<List<Double>>();  
            List<List<Double>> testDatas = new ArrayList<List<Double>>();  
            t.read(datas, datafile);  
            t.read(testDatas, testfile);  
            KNN knn = new KNN();  
            for (int i = 0; i < testDatas.size(); i++) {  
                List<Double> test = testDatas.get(i);  
                System.out.print("测试元组: ");  
                for (int j = 0; j < test.size(); j++) {  
                    System.out.print(test.get(j) + " ");  
                }  
                System.out.print("类别为: ");  
                System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 2)))));  
            }  
        } catch (Exception e) {  
            e.printStackTrace();  
        }  
    }  
}

  datafile.txt文件内容:

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1  
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1  
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1  
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0  
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1  
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

  testfile.txt文件内容:

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5  
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8  
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2  
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5  
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5  
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

 最终的运行结果:

测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0
测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1
测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0

 下面的工作就是如何让汉字也成为如上的Long类型的数字呢,我们现在使用词频的空间向量来代替这些文字。

 

   发表时间:2011-05-12  
楼主 这个系列文章怎么没有更新了,大家都等着看呢?楼主 加油呀
0 请登录后投票
   发表时间:2011-06-02  
多谢楼主, 正在研究分类聚类,学习了!
0 请登录后投票
论坛首页 综合技术版

跳转论坛:
Global site tag (gtag.js) - Google Analytics