`
gstarwd
  • 浏览: 1488229 次
  • 性别: Icon_minigender_1
  • 来自: 杭州
社区版块
存档分类
最新评论

AI 决策树ID3 代码(c++)

阅读更多

http://blog.csdn.net/cctt_1/archive/2009/02/03/3860725.aspx

 

源代码工程文件(vs2005)http://d.download.csdn.net/down/1018461/cctt_1

过去在网上找了段代码,发现写的代码要改些地方,而且也想顺便练习下自己的c++编码。

首先我要建立一个真正的树形结构。于是使用了自己过去的GeneralTree.h (当然这里还是改动些GeneralTree的代码例如增添了些函数,另外把有些私有函数变成了公有函数)。

训练文本格式如下:并命名为decision2.txt 并发在自己的工程目录下。当然你也可以改改相关源代码

概念  颜色 形状 轻重 
苹果   红   球   一般
苹果   绿   球   一般
香蕉   黄   弯月  一般
草莓   红   球    轻
草莓   绿   球    轻
西瓜   绿   椭球  重
西瓜   绿   球    重
桔子   桔黄 椭球  轻

测试格式文本格式如下:命名为test.txt并放在工程目录下(试试改改源代码)

颜色 形状 轻重 
红   球   一般
绿   球   一般
黄   弯月  一般

这里应该考虑各个类分开的。不过为了看起来方便,就合在一起了。

下面是具体代码:

  1. /* created by chico chen  
  2. *  date 2009/02/02  
  3. *  如需转载注明出处  
  4. */   
  5. #include "stdafx.h"   
  6. #include <iostream>   
  7. #include <fstream>   
  8. #include <string>   
  9. #include <sstream>   
  10. #include <vector>   
  11. #include <map>   
  12. #include <cmath>   
  13. #include "D:\\Tools\\not Finished\\TreeTest\\TreeTest\\GeneralTree.h"   
  14. using   namespace  std;  
  15. // this class is for computing attribute entropy    
  16. class  AttribDiff  
  17. {  
  18. public :  
  19.     string attribName; // 属性名   
  20.     map<string,int >  attribNum;     //具体属性和个数对   
  21.     map<string,map<string,int >> typeNumber;   
  22.     // 第一个string为具体属性名,第二个为类型,   
  23.     // int是类型在具体属性中的个数.   
  24.     // 例如:是否可见  类型    形状   
  25.     //          1       西瓜    圆   
  26.     //          1       冬瓜    扁   
  27.     //          0       橘子    圆   
  28.     // 其中具体属性为圆,类型为西瓜等个数为形状为圆的类型为西瓜的个数   
  29.     AttribDiff(string& attribName)  
  30.     {  
  31.         this ->attribName = attribName;  
  32.     }  
  33.     // in order to computer entropy of an attribute   
  34.     double  AttribDifferComputer(vector<vector<string>> infos, int  i_attrib, int  i_types, vector< int >& visible)  
  35.     {  
  36.         double   probability = 0;  
  37.         double  entropy = 0;  
  38.         double  attribG = 0;  
  39.         map<string,int > temp;  
  40.         int  tempNum = 0;  
  41.         for ( int  i =0 ; i < infos.size(); i++)  
  42.         {  
  43.             if (visible[i] != 0 )  
  44.             {  
  45.                 tempNum = attribNum[infos[i][i_attrib]];  
  46.                 attribNum[infos[i][i_attrib]] = ++tempNum;  
  47.                 temp = typeNumber[infos[i][i_attrib]];  
  48.                 tempNum = temp[infos[i][i_types]];  
  49.                 temp[infos[i][i_types]] = ++tempNum;  
  50.                 typeNumber[infos[i][i_attrib]] = temp;  
  51.             }  
  52.         }  
  53.         map<string,int >::iterator i_number;  
  54.         map<string,int >::iterator i_type;  
  55.       
  56.         for (i_number = attribNum.begin(); i_number != attribNum.end(); i_number++)  
  57.         {  
  58.               
  59.             probability = (*i_number).second/(double )infos.size();  
  60.             cout <<(*i_number).first <<"概率为:" << probability<<endl;  
  61.             entropy = 0;  
  62.               
  63.             for (i_type = typeNumber[(*i_number).first].begin(); i_type != typeNumber[(*i_number).first].end(); i_type++)  
  64.             {  
  65.                 entropy += ComputerEntropyHelp((*i_type).second/(double )(*i_number).second);  
  66.             }  
  67.               
  68.             attribG += (-1)*probability * entropy;  
  69.               
  70.         }  
  71.           
  72.         return  attribG;  
  73.     }  
  74.     // compute the entropy   
  75.     double  ComputerEntropyHelp( double  pi)  
  76.     {  
  77.         return  pi*log(pi)/log(( double )2);  
  78.     }  
  79. };  
  80. // this class is create a data struct for general tree node   
  81. class  NodeInfo  
  82. {  
  83. public :  
  84.     // 颜色   
  85.     // 红   
  86.     // 蓝   
  87.     string attribName; //  the attribute name, such as 颜色   
  88.     vector<string> detailAttrib; // all of detail attributes under one of attribute name, for example, 红   
  89.     NodeInfo()  
  90.     {  
  91.         attribName = "" ;  
  92.     }  
  93.     NodeInfo(string & attribName)  
  94.     {  
  95.         this ->attribName = attribName;  
  96.     }  
  97.     NodeInfo(NodeInfo & ni)  
  98.     {  
  99.         this ->attribName = ni.attribName;  
  100.         this ->detailAttrib = ni.detailAttrib;  
  101.     }  
  102.     // add detail attributes in NodeInfo   
  103.     void  NodeDetailInfoAdd(string & detailA)  
  104.     {  
  105.         if (!CheckIsHas(detailA))  
  106.         {  
  107.             this ->detailAttrib.push_back(detailA);  
  108.         }  
  109.     }  
  110.     // If detail attribute is in the detailAttrib list, return true;   
  111.     // else return false;   
  112.     bool  CheckIsHas(string & name)  
  113.     {  
  114.         for ( int  i = 0; i < detailAttrib.size(); i++)  
  115.         {  
  116.             if (strcmp(name.c_str(),detailAttrib[i].c_str()) ==0)  
  117.             {  
  118.                 // the same attribute   
  119.                 return   true ;  
  120.             }  
  121.         }  
  122.         return   false ;  
  123.     }  
  124.     // this is print control for printing NodeInfo   
  125.     static   void  Print(NodeInfo& info)  
  126.     {  
  127.         cout << info.attribName<< "(" ;  
  128.           
  129.           
  130.         for ( int  i = 0; i < info.detailAttrib.size() ; i++)  
  131.         {  
  132.             cout <<  info.detailAttrib[i]<<" " ;  
  133.         }  
  134.         cout << ")\n" ;  
  135.           
  136.     }  
  137.       
  138. };  
  139. // this class is decision tree   
  140. class  DT  
  141. {  
  142. protected :  
  143.     const  string filename;  //  the data file name   
  144.     vector<vector<string>> infos; //  the array for storing information   
  145.     vector<string> attribs;  //  the array for storing the attributes   
  146.     GeneralTree<NodeInfo>gt; // the general tree for storing the decision tree   
  147.     const   int  START;  //  which  column is the start attribute, except the type column   
  148.     const   int  I_TYPE; // the column index of type   
  149.     const   int  MAX_ENTROPY;  // set an max entropy to find the minimal entropy   
  150. private :  
  151.     // to help print   
  152.     void  PrintHelp( int  helpPrint)  
  153.     {  
  154.         for ( int  i = 0; i < helpPrint; i++)  
  155.         {  
  156.             cout << ".." ;  
  157.         }  
  158.     }  
  159.     // to find the index of the attribName in attribs array   
  160.     int  Find(string& attribName,vector<string>& attribs)  
  161.     {  
  162.         for ( int  i = 0; i < attribs.size(); i++)  
  163.         {  
  164.             if (strcmp(attribName.c_str(),attribs[i].c_str()) == 0)  
  165.             {  
  166.                 // the same    
  167.                 return  i;  
  168.             }  
  169.         }  
  170.         return  -1;  
  171.     }  
  172.     // this function is used for detecting if the arithmetic is over   
  173.     bool  CheckOver(vector< int >& visible,string& type)  
  174.     {  
  175.         map<string,int > types;  
  176.         int  temp = 0;  
  177.         for ( int  i = 0; i < infos.size(); i++)  
  178.         {  
  179.             if (visible[i] != 0)  
  180.             {  
  181.                 type = infos[i][I_TYPE];  
  182.                 temp = types[infos[i][I_TYPE]];  
  183.                 if (temp == 0)  
  184.                 {  
  185.                     types[infos[i][I_TYPE]] = 1;  
  186.                 }  
  187.                 if (types.size() > 1)  
  188.                 {  
  189.                     return   false // there are more than one types   
  190.                 }  
  191.             }  
  192.         }  
  193.         return   true //  there is only one type   
  194.     }  
  195.     // to create a Decision Tree   
  196.     void  DTCreate(GeneralTreeNode<NodeInfo> *parent, vector< int > visible,vector< int > visibleA,  int  i_attrib,string& detailA,  int  helpPrint)  
  197.     {  
  198.         if (i_attrib >= START)  
  199.         {  
  200.             for ( int  i = 0; i < infos.size(); i++)  
  201.             {  
  202.                 if (strcmp(infos[i][i_attrib].c_str(),detailA.c_str()) !=0)  
  203.                 {  
  204.                     // not same detail attribute   
  205.                     visible[i] = 0;  
  206.                 }  
  207.             }  
  208.         }  
  209.         string type = "" ;  
  210.         if (CheckOver(visible,type))  
  211.         {  
  212.             // the arithmetic is over and add the type node into the general tree   
  213.             NodeInfo n(type);  
  214.             GeneralTreeNode<NodeInfo> * node = new  GeneralTreeNode<NodeInfo>(n);  
  215.             gt.Insert(node,parent);  
  216.             PrintHelp(helpPrint);  
  217.             cout << "decision type:" <<n.attribName<<endl;  
  218.             return ;  
  219.         }  
  220.       
  221.         map<string,double > attribGs;  //  this is for deciding which attrib should be used   
  222.           
  223.         for ( int  i = START; i < attribs.size(); i++)  
  224.         {  
  225.             // iterate attribs   
  226.             if (visibleA[i] != 0)  
  227.             {  
  228.                 AttribDiff ad(attribs[i]);  
  229.                 attribGs[attribs[i]] = ad.AttribDifferComputer(infos,i,I_TYPE,visible);  
  230.                 cout <<attribs[i] <<"的G为:" << attribGs[attribs[i]]<<endl;  
  231.             }  
  232.         }  
  233.         //  to find the decision attribute   
  234.         double  min = MAX_ENTROPY;   
  235.         string attributeName;  
  236.         for (map<string, double >::iterator i = attribGs.begin(); i != attribGs.end(); i++)  
  237.         {  
  238.               
  239.             if (min >= (*i).second)  
  240.             {  
  241.                 attributeName = (*i).first;  
  242.                 min = (*i).second;  
  243.             }  
  244.         }  
  245.         NodeInfo n(attributeName);  
  246.         int  i_max = Find(attributeName,attribs);  
  247.         for ( int  i = 0; i<infos.size() ; i++)  
  248.         {  
  249.             n.NodeDetailInfoAdd(infos[i][i_max]);  
  250.         }  
  251.         GeneralTreeNode<NodeInfo> * node = new  GeneralTreeNode<NodeInfo>(n);  
  252.         gt.Insert(node,parent);  
  253.         visibleA[i_max] = 0;  
  254.         PrintHelp(helpPrint);  
  255.         cout << "choose attribute:" << attributeName<<endl;  
  256.         for ( int  i = 0; i < node->data.detailAttrib.size(); i++)  
  257.         {  
  258.             PrintHelp(helpPrint);  
  259.             cout << "go into the branch:" <<node->data.detailAttrib[i]<<endl;  
  260.             // go to every branch to decision   
  261.             DTCreate(node,visible,visibleA,i_max,node->data.detailAttrib[i],helpPrint+1);  
  262.         }  
  263.           
  264.     }  
  265. public :       
  266.     // 要注意的一点是这里的decision2.txt要放在工程目录下。当然如果你愿意可以写绝对路径   
  267.     // 注意文件的格式:   
  268.     // 首先一列为类别,然后是各个属性   
  269.     // 例如:  类型    形状   
  270.     //         西瓜    圆   
  271.     //         冬瓜    扁   
  272.     //         橘子    圆   
  273.     DT():filename("decision2.txt" ),START(1),I_TYPE(0),MAX_ENTROPY(10000)  
  274.     {  
  275.         GetInfo(attribs,infos,filename);  
  276.         DTCreate();  
  277.           
  278.     }  
  279.        
  280.     // this function is used for read data from the file   
  281.     // and create the attribute array and all information array   
  282.     // post: attribs has at least one element   
  283.     //       infos has at least one element   
  284.     // pre: filename is not empty and the file is exist   
  285.     void  GetInfo(vector<string>& attribs,vector<vector<string>>& infos, const  string& filename)  
  286.     {  
  287.         ifstream read(filename.c_str());  
  288.           
  289.         int  start = 0;  
  290.         int  end = 0;  
  291.         string info = "" ;  
  292.         getline(read,info);  
  293.         istringstream iss(info);  
  294.         string attrib;  
  295.           
  296.         while (iss >> attrib)  
  297.         {  
  298.             attribs.push_back(attrib);  
  299.         }  
  300.         while ( true )  
  301.         {  
  302.             info = "" ;  
  303.             getline(read,info);  
  304.             if (info ==  ""  || info.length() <= 1)  
  305.             {  
  306.                 break ;  
  307.             }  
  308.             vector<string> infoline;  
  309.             istringstream stream(info);  
  310.               
  311.             while (stream >> attrib)  
  312.             {  
  313.                 infoline.push_back(attrib);  
  314.             }  
  315.             infos.push_back(infoline);  
  316.         }  
  317.         read.close();  
  318.     }  
  319.     // create the DT   
  320.     void  DTCreate()  
  321.     {  
  322.         vector<int > visible(infos.size(),1);  
  323.         vector<int > visibleA(attribs.size(),1);  //  to judge which attribute is useless   
  324.         string temp = "" ;  
  325.         DTCreate(NULL,visible,visibleA,START-1,temp,0);  
  326.     }  
  327.     // print the DT   
  328.     void  Print()  
  329.     {  
  330.           
  331.         gt.LevelPrint(NodeInfo::Print);  
  332.     }  
  333.     void  Judge( const  string& testFilename,vector<string>& types, const  string& testResultFileName)  
  334.     {  
  335.         vector<string> attribs_test;  
  336.         vector<vector<string>> infos_test;  
  337.         GetInfo(attribs_test,infos_test,testFilename);  
  338.           
  339.         if (!CheckFileFormat(attribs_test))  
  340.         {  
  341.             throw   "file format error" ;  
  342.         }  
  343.         GeneralTreeNode<NodeInfo> * root = gt.GetRoot();  
  344.         for ( int  i = 0; i < infos_test.size(); i++)  
  345.         {  
  346.               
  347.             types.push_back(JudgeType(root,infos_test[i],attribs_test));  
  348.         }  
  349.         WriteTestTypesInfo(testResultFileName,types);  
  350.     }  
  351.     void  WriteTestTypesInfo( const  string& filename, vector<string>& types)  
  352.     {  
  353.         ofstream out(filename.c_str());  
  354.         out << "类别" <<endl;  
  355.         for ( int  i = 0 ; i < types.size(); i++)  
  356.         {  
  357.             out << types[i]<<endl;  
  358.         }  
  359.         out.close();  
  360.     }  
  361.     string JudgeType(GeneralTreeNode<NodeInfo> * node, vector<string>& info,vector<string>& attribs_test)  
  362.     {  
  363.         if (gt.GetChildNodeNum(node) == 0)  
  364.         {  
  365.             return  node->getData().attribName;  
  366.         }  
  367.         int  index = Find(node->getData().attribName,attribs_test);  
  368.         int  branch_index = Find(info[index],node->getData().detailAttrib);  
  369.         if (branch_index == -1)  
  370.         {  
  371.             // is not find this detail attribute in this node detailAttrib   
  372.             // there are two way to deal with this situation   
  373.             // 1. every branch has possibility to choose   
  374.             // 2. no such type and can not judge   
  375.             // the first solution make the correct ratio low   
  376.             // the second solution has no fault-tolerance.   
  377.             // and here I choose the second solution.   
  378.             // if I have more free time later, I will write the first solution   
  379.             throw   "no such type" ;  
  380.         }  
  381.         GeneralTreeNode<NodeInfo> * childNode = gt.GetAChild(node,branch_index);  
  382.         return  JudgeType(childNode, info,attribs_test);  
  383.     }  
  384.     bool  CheckFileFormat(vector<string>& attribs_test)  
  385.     {  
  386.         bool  isCorrect =  true ;  
  387.         for ( int  j = 0; j < attribs_test.size(); j++)  
  388.         {  
  389.             if (Find(attribs_test[j],attribs) == -1)  
  390.             {  
  391.                 isCorrect = false ;  
  392.             }  
  393.         }  
  394.         if (attribs_test.size() == attribs.size() - 1)  
  395.         {  
  396.             isCorrect = isCorrect && true ;  
  397.         }  
  398.         else   
  399.         {  
  400.             isCorrect = false ;  
  401.         }  
  402.         return  isCorrect;  
  403.     }  
  404. };  

这里的main函数这样写(自己使用的VS2005):

  1. int  _tmain( int  argc, _TCHAR* argv[])  
  2. {  
  3.     DT dt;  
  4.     //dt.Print();   
  5.     string  testFile =  "test.txt" ;  
  6.     string  testResult =  "testResult.txt" ;  
  7.     vector<string >types;  
  8.     dt.Judge(testFile,types,testResult);  
  9.     return  0;  
  10. }  

自己感觉DT 的注释比较详细,所以在我的blog中就不再做太多的解释。另外这段代码会将测试结果放在工程目录下的testResult.txt中。

另外在控制台上会有生成决策树ID3的相关相关的信息显示,例如:

红概率为:0.25
黄概率为:0.125
桔黄概率为:0.125
绿概率为:0.5
颜色的G为:1
球概率为:0.625
椭球概率为:0.25
弯月概率为:0.125
形状的G为:1.20121
轻概率为:0.375
一般概率为:0.375
重概率为:0.25
轻重的G为:0.688722
choose attribute:轻重
go into the branch:一般
红概率为:0.125
黄概率为:0.125
绿概率为:0.125
颜色的G为:0
球概率为:0.25
弯月概率为:0.125
形状的G为:0
..choose attribute:颜色
..go into the branch:红
....decision type:苹果
..go into the branch:绿
....decision type:苹果
..go into the branch:黄
....decision type:香蕉
..go into the branch:桔黄
....decision type:
go into the branch:轻
红概率为:0.125
桔黄概率为:0.125
绿概率为:0.125
颜色的G为:0
球概率为:0.25
椭球概率为:0.125
形状的G为:0
..choose attribute:颜色
..go into the branch:红
....decision type:草莓
..go into the branch:绿
....decision type:草莓
..go into the branch:黄
....decision type:
..go into the branch:桔黄
....decision type:桔子
go into the branch:重
..decision type:西瓜

这一段信息是什么意思呢?

红概率为:0.25
黄概率为:0.125
桔黄概率为:0.125
绿概率为:0.5
颜色的G为:1

红,黄,桔黄,绿的概率是颜色的具体属性。这里没有把entropy打印出来。如果此段代码被中科院的师弟师妹有幸看到,

你 们可以在AttribDifferComputer()函数中添加几行代码就可以把每一个entropy打印出来。反正老师也会让你们看代码,这里就当作 作业题吧。(另外老师第十章机器学习ppt上的决策树的这个例子计算结果有错误。如果你认真计算过的话)颜色G的含义是颜色G的决策值,决策值越小,选择 此属性的概率就越大。


那决策树是什么样子的呢?

choose attribute:轻重
go into the branch:一般

..choose attribute:颜色
..go into the branch:红

......................

看看上面的这些.这里代表根节点是“轻重”,然后进入“一般”分支,然后进入“一般”分支的节点为颜色..然后进入”红“分支.这里一定要注意”..“,相等的"..”代表树的相同的层次。


做出这个Decision Tree 的ID3代码主要是为了学弟学妹们在考试中测试用的。因为我只是测试了老师ppt中的例子,不保证对于所有的例子都正确。而且老师出的考试题比较变态(属性十个左右)..如果手工计算应该需要一个小时左右的时间。

当初后悔没有先编一个程序。祝各位考试顺利..(我想我这段代码可能会在考试之前被搜到)。


同时提醒大家一点, ID3也不是什么很好的算法。当两个属性的G值一致时,如果它并不能给出一个更好的判断标准。而且如果采用顺序选择很有可能生成一个非最小决策树。这点还值得研究一下。

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics