Mahout's trainnb command invokes TrainNaiveBayesJob to train the model. The class lives in the package:
org.apache.mahout.classifier.naivebayes.training
TrainNaiveBayesJob takes as input the portion of the tfidf vectors that was split off for training.
Walking through the TrainNaiveBayesJob code:
It first registers a few command-line options, such as:
LABELS (-l), ALPHA_I (-a), LABEL_INDEX (-li), TRAIN_COMPLEMENTARY (-c)
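For orientation, a typical invocation along the lines of the classify-20newsgroups example script looks roughly like this (the paths here are illustrative, not fixed):

mahout trainnb \
  -i 20news-train-vectors \   # tfidf vectors split off for training
  -el \                       # extract labels from the input paths
  -li labelindex \            # where the label index is written
  -o model \                  # output directory for the model
  -ow -c                      # overwrite output; train complementary NB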
It then reads the labels from the input and stores them in a label index. In the 20newsgroups example two labels are read, and the label index looks like this:
Key class: class org.apache.hadoop.io.Text Value Class: class org.apache.hadoop.io.IntWritable
Key: 20news-bydate-test: Value: 0
Key: 20news-bydate-train: Value: 1
In other words, this simply builds an index over the categories.
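The label index is itself a SequenceFile of Text/IntWritable pairs, so dumping it only takes a plain Hadoop reader. A minimal sketch, assuming the index was written to a path named "labelindex" (adjust to wherever -li points):

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;

public class DumpLabelIndex {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    // "labelindex" is the path given to -li; adjust to your setup
    SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path("labelindex"), conf);
    Text label = new Text();
    IntWritable index = new IntWritable();
    while (reader.next(label, index)) {
      System.out.println("Key: " + label + " Value: " + index.get());
    }
    reader.close();
  }
}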
Next, all vectors carrying the same label are summed, i.e. the vectors of every document in a category are added together. Each vector here is really a sparse key/value vector whose entries pair a term id with its TFIDF value. After summation there is one vector per class: entries with matching term ids have their TFIDF values added, and entries present on only one side are inserted, much like merging two sorted linked lists. A single job does this:
// Key class: class org.apache.hadoop.io.Text
// Value Class: class org.apache.mahout.math.VectorWritable

//add up all the vectors with the same labels, while mapping the labels into our index
Job indexInstances = prepareJob(getInputPath(),      //input path
    getTempPath(SUMMED_OBSERVATIONS),                //output path
    SequenceFileInputFormat.class,                   //input format
    IndexInstancesMapper.class,                      //mapper class
    IntWritable.class,                               //mapper key
    VectorWritable.class,                            //mapper value
    VectorSumReducer.class,                          //reducer class
    IntWritable.class,                               //reducer key
    VectorWritable.class,                            //reducer value
    SequenceFileOutputFormat.class);                 //output format
indexInstances.setCombinerClass(VectorSumReducer.class);
boolean succeeded = indexInstances.waitForCompletion(true);
if (!succeeded) {
  return -1;
}
The mapper is IndexInstancesMapper and the reducer is VectorSumReducer. The code is fairly simple:
protected void map(Text labelText, VectorWritable instance, Context ctx)
    throws IOException, InterruptedException {
  String label = labelText.toString().split("/")[1];
  if (labelIndex.containsKey(label)) {
    // the class index read from the label-index file becomes the key
    ctx.write(new IntWritable(labelIndex.get(label)), instance);
  } else {
    ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
  }
}

// sum up the vectors that share the same key
protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
    throws IOException, InterruptedException {
  Vector vector = null;
  for (VectorWritable v : values) {
    if (vector == null) {
      vector = v.get();
    } else {
      vector.assign(v.get(), Functions.PLUS);
    }
  }
  ctx.write(key, new VectorWritable(vector));
}
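To make the reducer's merge semantics concrete, here is a minimal standalone sketch against Mahout's math API (the term ids and weights are toy values, not taken from the 20news data): entries sharing a term id are added, the rest are carried over.

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class VectorMergeDemo {
  public static void main(String[] args) {
    // two sparse TFIDF vectors of the same (toy) dimensionality
    Vector doc1 = new RandomAccessSparseVector(100);
    doc1.setQuick(3, 1.5);   // term 3 appears only in doc1
    doc1.setQuick(7, 2.0);   // term 7 appears in both

    Vector doc2 = new RandomAccessSparseVector(100);
    doc2.setQuick(7, 0.5);
    doc2.setQuick(42, 3.0);  // term 42 appears only in doc2

    // element-wise sum, exactly what VectorSumReducer does
    doc1.assign(doc2, Functions.PLUS);
    System.out.println(doc1);  // {3:1.5, 7:2.5, 42:3.0}
  }
}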
OK, at this point we have <label_index, label_vector> pairs: the class id together with the summed TFIDF weights of every item (i.e. feature) in that class. The step produces output like the following:
Key: 0 Value: /comp.sys.ibm.pc.hardware/60252:{93562:17.52922821044922,93559:9.745443344116211,93558:107.53932094573975,93557:49.015570640563965,93556:9.745443344116211……}
Key: 1 Value: /alt.atheism/53261:{93562:26.293842315673828,93560:19.490886688232422,93559:9.745443344116211,93558:78.52010536193848,93557:62.2713,93555:14.35555171……}
The next stage sums up the TFIDF weights from the previous step, per label and per feature. It takes the previous step's output as input and again runs as a job:
//sum up all the weights from the previous step, per label and per feature
Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
    getTempPath(WEIGHTS),
    SequenceFileInputFormat.class,
    WeightsMapper.class,
    Text.class,
    VectorWritable.class,
    VectorSumReducer.class,
    Text.class,
    VectorWritable.class,
    SequenceFileOutputFormat.class);
weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
weightSummer.setCombinerClass(VectorSumReducer.class);
succeeded = weightSummer.waitForCompletion(true);
if (!succeeded) {
  return -1;
}
The job's mapper is WeightsMapper and the reducer is the same VectorSumReducer as in the previous step. Note that the map output key class is now Text, because the mapper emits the two constant keys shown further below instead of label ids. The mapper:
protected void map(IntWritable index, VectorWritable value, Context ctx)
    throws IOException, InterruptedException {
  Vector instance = value.get();
  if (weightsPerFeature == null) {
    weightsPerFeature =
        new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
  }
  int label = index.get();
  // accumulate the per-feature sums across all labels
  weightsPerFeature.assign(instance, Functions.PLUS);
  // add this vector's total weight to its label's slot
  weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
}
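A small trace of what this accumulates over the two per-label vectors from the previous step (toy numbers; weightsPerFeature and weightsPerLabel stand in for the mapper's fields, which WeightsMapper itself sizes from NUM_LABELS in its setup):

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class WeightsDemo {
  public static void main(String[] args) {
    Vector label0 = new RandomAccessSparseVector(100);
    label0.setQuick(5, 17.5);
    label0.setQuick(9, 49.0);
    Vector label1 = new RandomAccessSparseVector(100);
    label1.setQuick(5, 26.3);
    label1.setQuick(8, 19.5);

    Vector weightsPerFeature = new RandomAccessSparseVector(100);
    Vector weightsPerLabel = new DenseVector(2);  // one slot per label

    // what map() does for records (0, label0) and (1, label1)
    weightsPerFeature.assign(label0, Functions.PLUS);
    weightsPerLabel.set(0, weightsPerLabel.get(0) + label0.zSum()); // 66.5
    weightsPerFeature.assign(label1, Functions.PLUS);
    weightsPerLabel.set(1, weightsPerLabel.get(1) + label1.zSum()); // 45.8

    System.out.println(weightsPerFeature); // {5:43.8, 8:19.5, 9:49.0}
    System.out.println(weightsPerLabel);   // {0:66.5, 1:45.8}
  }
}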
The step's output is written in cleanup():
protected void cleanup(Context ctx) throws IOException, InterruptedException {
  // emit the two aggregate vectors under their constant keys
  if (weightsPerFeature != null) {
    ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
    ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
  }
  super.cleanup(ctx);
}
In other words, the output holds only two key/value pairs:
one is WEIGHTS_PER_FEATURE (a constant whose value is __SPF),
the other is WEIGHTS_PER_LABEL (__SPL).
weightsPerFeature is the element-wise sum of all the per-label vectors from the previous step, i.e. the total TFIDF weight of each feature across every label; for term 93562, 17.529 + 26.294 = 43.823, which matches the dump below. weightsPerLabel is the total TFIDF weight within each label, i.e. the zSum() of that label's vector.
The output therefore looks like:
Key: __SPF Value: {93562:43.82307052612305,93560:19.490886688232422,93559:19.490886688232422,93558:186.05942630767822,93557:111.28696632385254,93556:9.745443344116211……}
Key: __SPL Value: {1:7085520.472989678,0:4662610.912284017}
On to the last job. First, the source:
//calculate the Thetas, write out to LABEL_THETA_NORMALIZER vectors --
// TODO: add reference here to the part of the Rennie paper that discusses this
Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
    getTempPath(THETAS),
    SequenceFileInputFormat.class,
    ThetaMapper.class,
    Text.class,
    VectorWritable.class,
    VectorSumReducer.class,
    Text.class,
    VectorWritable.class,
    SequenceFileOutputFormat.class);
thetaSummer.setCombinerClass(VectorSumReducer.class);
thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
/* TODO(robinanil): Enable this when thetanormalization works.
succeeded = thetaSummer.waitForCompletion(true);
if (!succeeded) {
  return -1;
}*/
Notice that thetaSummer.waitForCompletion(true) is commented out, so this job is never actually executed. The Rennie paper mentioned in the comment is the one Mahout's Bayes implementation follows: "Tackling the Poor Assumptions of Naive Bayes Text Classifiers", which gives the formula for estimating θ. Why the job was disabled is not explained; the TODO suggests theta normalization simply did not work yet.
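For reference, the smoothed θ estimate in that paper (rendered here from the paper's notation, since the original figure is missing from this page) is

    θ̂_ci = (N_ci + α_i) / (N_c + α)

where N_ci is the weight of term i in class c (here, the summed TFIDF), N_c = Σ_i N_ci, α_i is the smoothing parameter passed via -a / ALPHA_I, and α = Σ_i α_i. Note that N_ci and N_c are exactly what weightsPerFeature and weightsPerLabel already contain, so the θ values can be recovered from the model later even without this job.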
And the very last step: strictly speaking, the model is already complete once weightsPerFeature and weightsPerLabel exist. This step just assembles everything into matrix form, one weight vector per row:
____|item1, item2, item3……
lab1|
lab2|
……
The source code:
// read the model (backed by a SparseMatrix) back from the temp path
NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
naiveBayesModel.validate();
// serialize it, writing output/naiveBayesModel.bin
naiveBayesModel.serialize(getOutputPath(), getConf());
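Once serialized, the model can be loaded back and used to classify new tfidf vectors. A minimal sketch, assuming "model" is the (hypothetical) directory trainnb wrote to, using Mahout's naivebayes classifier API:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class ClassifyDemo {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // directory trainnb wrote the model to (hypothetical path)
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path("model"), conf);
    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // a toy tfidf vector; its cardinality must match the training vectors
    Vector instance = new RandomAccessSparseVector(100000);
    instance.setQuick(93562, 17.5);

    // one score per label; the index of the largest score is the predicted label
    Vector scores = classifier.classifyFull(instance);
    System.out.println("predicted label index: " + scores.maxValueIndex());
  }
}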
THE END
http://hnote.org/big-data/mahout/mahout-train-naive-bayes-job