`

Java实现的朴素贝叶斯分类器

阅读更多

目前的算法只能处理结果只有两种的情况,即true or false. 多分枝或者是数字类型的还无法处理。

用到的一些基础数据结构可以参考上一篇关于ID3的代码。 

 

这里只贴出来实现贝叶斯分类预测的部分:

package classifier;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import util.ArffUtil;


/**
 * NBC means Naive Bayes Classifier
 * @author wenjun_yang
 *
 */
public class NBCUtil {
	
	ArffUtil util = new ArffUtil();
	private List<String> attributeList = null;
	private List<String[]> dataList = null;
	private String decAttributeName = null;
	private int decAttributeIndex = -1;
	
	private Map<String, List<String[]>> seperatedDataTable = null;
	public NBCUtil(String decAttributeName, List<String> attributeList, List<String[]> dataList) {
		this.attributeList = attributeList;
		this.dataList = dataList;
		this.decAttributeName = decAttributeName;
		
		this.decAttributeIndex = util.getValueIndex(decAttributeName, this.attributeList);
		this.seperatedDataTable = seperateDataList(dataList);
	}
	
	private Map<String, List<String[]>> seperateDataList(List<String[]> dataList) {
		Map<String, List<String[]>> map = new HashMap<String, List<String[]>>();
		
		for(String[] arr : dataList) {
			if(decAttributeIndex >= 0 && decAttributeIndex < arr.length) {
				String currentKey = arr[decAttributeIndex]; 
				if(map.containsKey(currentKey)) {
					List<String[]> tempList = map.get(currentKey);
					tempList.add(arr);
					map.put(currentKey, tempList);
				} else {
					List<String[]> tempList = new ArrayList<String[]>();
					tempList.add(arr);
					map.put(currentKey , tempList);
				}
			}
		}
		
		return map;
	}
	
	public Boolean predict(Map<String, String> predictData, String targetDecAttributeValue) {
		if(predictData.containsKey(decAttributeName)) predictData.remove(decAttributeName);
		
		List<String[]> positiveDataTable = new ArrayList<String[]>();
		if(seperatedDataTable.containsKey(targetDecAttributeValue)) {
			positiveDataTable = seperatedDataTable.get(targetDecAttributeValue);
		}
		
		double resultP = 1.;
		
		// Step 1: 逐个属性的比率进行计算
		// 即: 计算 P(Attr=Value|Y=true) / P(Attr=Value|Y=false) 的值
		for(String attrName : predictData.keySet()) {
			String attrValue = predictData.get(attrName);
			int attrIndex = util.getValueIndex(attrName, attributeList);
			int attrPositiveCount = 0;
			int attrNegativeCount = 0;
			
			for(String[] arr : dataList) {
				if(arr[attrIndex].equals(attrValue)) {
					if(arr[decAttributeIndex].equals(targetDecAttributeValue)) {
						attrPositiveCount++;
					} else {
						attrNegativeCount++;
					}
				}
			}
			double temp =  (attrPositiveCount / (double)positiveDataTable.size() ) /
							(attrNegativeCount / (double)(dataList.size() - positiveDataTable.size()));
			resultP *= temp;
		}
		// 最后计算 P(Y=true) / P(Y=false)
		resultP *= positiveDataTable.size() / (double)(dataList.size() - positiveDataTable.size());
		System.out.println(resultP);
		if(resultP > 1) {
			return true;
		} else {
			return false;
		}
	}
}

 

 

完整的项目也上传了,可以直接使用。

数据源来自weka

1
0
分享到:
评论
3 楼 酷呀嗒 2015-04-16  
给的文件没有main函数 这个怎么启动啊
2 楼 RangerWolf 2015-01-11  
caihongshijie6 写道
你好,这里面用到的数据文件在哪里找呢?


Weka SciKit-learn 都自带不少数据
另外网上也有不少公开的数据
1 楼 caihongshijie6 2015-01-03  
你好,这里面用到的数据文件在哪里找呢?

相关推荐

Global site tag (gtag.js) - Google Analytics