/**
 * Function: ID3 algorithm
 * Language: C++
 * Author:   刘永康
 * Version:  1.0
 * Date:     2012.12.3
 */

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <utility>
#include <cmath>

int attribute_quantity;                          // number of attributes
int target_attribute_num;                        // index of the target attribute
std::vector<std::string> attributes;             // all attribute names
int example_quantity;                            // number of training examples
std::vector<std::vector<std::string> > examples; // all training examples
const std::string yes("yes");
const std::string no("no");

struct node {
    std::string attribute;          // attribute tested at this node ("yes"/"no" at a leaf)
    std::string decision_attribute; // parent's attribute value that led to this node
    std::vector<node> child;        // children of this node
};
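
/*
 * For the sample run at the bottom of this page, the root node ends up with
 * attribute "outlook" and three children: overcast -> "yes", rainy -> a
 * "wind" subtree, and sunny -> a "humidity" subtree (see the stdout section).
 */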

/**
 * Find the best splitting attribute -- the core of the ID3 algorithm.
 * cur_examples holds the indices of the training examples under consideration.
 * attributes_num holds the indices of the candidate attributes.
 */
int FindBestAttribute(std::vector<int> &cur_examples, std::vector<int> &attributes_num)
{
    int positive, negative, n;
    double entropy;

    // Count the positive and negative examples.
    positive = 0;
    negative = 0;
    n = cur_examples.size();
    for (int i = 0; i != n; ++i) {
        if (examples[cur_examples[i]][target_attribute_num].compare(yes)) {
            ++negative;
        } else {
            ++positive;
        }
    }

    // Entropy of the current example set. Both counts are nonzero here,
    // because DecisionTreeBuild returns early on pure nodes before calling us.
    entropy = -(positive * 1.0 / n) * log(positive * 1.0 / n) / log(2.0)
              - (negative * 1.0 / n) * log(negative * 1.0 / n) / log(2.0);
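    // For the sample input below (9 positive / 5 negative examples) this is
    // Entropy(S) = -(9/14)*log2(9/14) - (5/14)*log2(5/14) ≈ 0.940.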

    std::vector<double> gain(attributes_num.size(), entropy);

    for (int i = 0; i != attributes_num.size(); ++i) {
        int cur_positive, cur_negative, cur_n;
        double cur_entropy;
        std::map<std::string, int> values;

        // Collect every value this attribute takes, with its frequency.
        for (int j = 0; j != n; ++j) {
            std::map<std::string, int>::iterator iter = values.find(examples[cur_examples[j]][attributes_num[i]]);
            if (iter != values.end()) {
                iter->second++;
            } else {
                values.insert(std::make_pair(examples[cur_examples[j]][attributes_num[i]], 1));
            }
        }

        // Information gain of this attribute:
        // Gain(S, A) = Entropy(S) - sum over values v of (|S_v| / |S|) * Entropy(S_v)
        for (std::map<std::string, int>::iterator iter = values.begin(); iter != values.end(); ++iter) {
            // Count the positive and negative examples taking this value.
            cur_positive = 0;
            cur_negative = 0;
            cur_n = iter->second;
            for (int j = 0; j != n; ++j) {
                if (examples[cur_examples[j]][attributes_num[i]] == iter->first) {
                    if (examples[cur_examples[j]][target_attribute_num].compare(yes)) {
                        ++cur_negative;
                    } else {
                        ++cur_positive;
                    }
                }
            }

            // Entropy of this value's subset (zero if the subset is pure).
            if (cur_positive && cur_negative) {
                cur_entropy = -(cur_positive * 1.0 / cur_n) * log(cur_positive * 1.0 / cur_n) / log(2.0)
                              - (cur_negative * 1.0 / cur_n) * log(cur_negative * 1.0 / cur_n) / log(2.0);
            } else {
                cur_entropy = 0;
            }

            // Subtract the weighted subset entropy from this attribute's gain.
            gain[i] -= cur_entropy * cur_n / n;
        }
    }

    // The attribute with the largest information gain is the best split.
    int max = 0;
    for (int i = 0; i != attributes_num.size(); ++i) {
        if (gain[i] > gain[max]) {
            max = i;
        }
    }
    return max;
}
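
/*
 * For the sample input below these gains come out to the familiar textbook
 * values for the play-tennis data (Mitchell, "Machine Learning"):
 * Gain(S, outlook) ≈ 0.246, Gain(S, humidity) ≈ 0.151, Gain(S, wind) ≈ 0.048,
 * Gain(S, temperature) ≈ 0.029 -- so outlook is chosen at the root.
 */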

/**
 * Build the decision tree.
 * root is the node to fill in.
 * cur_examples holds the indices of the training examples under consideration.
 * attributes_num holds the indices of the candidate attributes.
 */
void DecisionTreeBuild(node &root, std::vector<int> &cur_examples, std::vector<int> &attributes_num)
{
    int positive, negative;
    int best_attribute;

    // Count the positive and negative examples.
    positive = 0;
    negative = 0;
    for (int i = 0; i != cur_examples.size(); ++i) {
        if (examples[cur_examples[i]][target_attribute_num].compare(yes)) {
            ++negative;
        } else {
            ++positive;
        }
    }

    // If all examples are positive, this node is a "yes" leaf.
    if (positive == cur_examples.size()) {
        root.attribute = yes;
        return;
    }

    // If all examples are negative, this node is a "no" leaf.
    if (negative == cur_examples.size()) {
        root.attribute = no;
        return;
    }

    // If no candidate attributes remain, label with the majority target value.
    if (attributes_num.size() == 0) {
        if (positive > negative) {
            root.attribute = yes;
        } else {
            root.attribute = no;
        }
        return;
    }

    // Pick the best splitting attribute.
    best_attribute = FindBestAttribute(cur_examples, attributes_num);
    root.attribute = attributes[attributes_num[best_attribute]];

    // Collect every value the best attribute takes.
    std::map<std::string, int> values;
    for (int i = 0; i != cur_examples.size(); ++i) {
        std::map<std::string, int>::iterator iter = values.find(examples[cur_examples[i]][attributes_num[best_attribute]]);
        if (iter != values.end()) {
            iter->second++;
        } else {
            values.insert(std::make_pair(examples[cur_examples[i]][attributes_num[best_attribute]], 1));
        }
    }

    // Recurse on each value of the best attribute.
    for (std::map<std::string, int>::iterator iter = values.begin(); iter != values.end(); ++iter) {
        // Training examples that take this value.
        std::vector<int> new_examples;
        for (int i = 0; i != cur_examples.size(); ++i) {
            if (examples[cur_examples[i]][attributes_num[best_attribute]] == iter->first) {
                new_examples.push_back(cur_examples[i]);
            }
        }

        // Remaining candidate attributes.
        std::vector<int> new_attributes_num;
        for (int i = 0; i != attributes_num.size(); ++i) {
            if (i != best_attribute) {
                new_attributes_num.push_back(attributes_num[i]);
            }
        }

        node new_node;
        new_node.decision_attribute = iter->first;
        if (new_examples.size()) { // non-empty subset: keep splitting
            DecisionTreeBuild(new_node, new_examples, new_attributes_num);
        } else { // empty subset: label with the majority target value
            if (positive > negative) {
                new_node.attribute = yes;
            } else {
                new_node.attribute = no;
            }
        }
        root.child.push_back(new_node);
    }
}

// Print the decision tree, indenting one tab per depth level.
void DecisionTreePrint(node &root, int n)
{
    for (int i = 0; i != n; ++i) {
        std::cout << '\t';
    }
    if (root.decision_attribute != "") {
        std::cout << root.decision_attribute << std::endl;
        for (int i = 0; i != n + 1; ++i) {
            std::cout << '\t';
        }
    }
    std::cout << root.attribute << std::endl;

    for (int i = 0; i != root.child.size(); ++i) {
        DecisionTreePrint(root.child[i], n + 1);
    }
}
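
/*
 * Usage sketch (not in the original listing): once the tree is built, a new
 * example can be classified by walking it. The helper name Classify and its
 * linear scan over the global attribute list are illustrative assumptions.
 */
std::string Classify(const node &root, const std::vector<std::string> &example)
{
    // A leaf carries the class label directly.
    if (root.child.empty()) {
        return root.attribute;
    }
    // Find the index of the attribute tested at this node.
    int attr = 0;
    for (int i = 0; i != attributes.size(); ++i) {
        if (attributes[i] == root.attribute) {
            attr = i;
        }
    }
    // Descend into the child matching the example's value for that attribute.
    for (int i = 0; i != root.child.size(); ++i) {
        if (root.child[i].decision_attribute == example[attr]) {
            return Classify(root.child[i], example);
        }
    }
    return ""; // attribute value never seen during training
}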

int main()
{
    std::string cur_s;
    std::vector<int> attributes_num;
    std::vector<int> cur_examples;
    node root;

    // Read the attribute count and the (1-based) target attribute index.
    std::cin >> attribute_quantity;
    std::cin >> target_attribute_num;
    target_attribute_num--;

    // Read the attribute names.
    for (int i = 0; i != attribute_quantity; ++i) {
        std::cin >> cur_s;
        attributes.push_back(cur_s);
    }

    // Read the training examples.
    std::cin >> example_quantity;
    for (int i = 0; i != example_quantity; ++i) {
        std::vector<std::string> cur_v;
        for (int j = 0; j != attribute_quantity; ++j) {
            std::cin >> cur_s;
            cur_v.push_back(cur_s);
        }
        examples.push_back(cur_v);
    }

    // Every attribute except the target is a candidate for splitting.
    for (int i = 0; i != attribute_quantity; ++i) {
        if (i != target_attribute_num) {
            attributes_num.push_back(i);
        }
    }

    // Start from the full set of training examples.
    for (int i = 0; i != example_quantity; ++i) {
        cur_examples.push_back(i);
    }

    DecisionTreeBuild(root, cur_examples, attributes_num);

    DecisionTreePrint(root, 0);

    return 0;
}
Success #stdin #stdout 0.02s 2840KB
stdin
5 
5 
outlook temperature humidity wind playtennis
14
sunny 		hot 	high 	weak 	no
sunny 		hot 	high 	strong 	no
overcast 	hot 	high 	weak 	yes
rainy 		mild 	high 	weak 	yes
rainy 		cool 	normal 	weak 	yes
rainy 		cool 	normal 	strong 	no
overcast 	cool 	normal 	strong 	yes
sunny 		mild 	high 	weak 	no
sunny 		cool 	normal 	weak 	yes
rainy 		mild 	normal 	weak 	yes
sunny 		mild 	normal 	strong 	yes
overcast 	mild 	high 	strong 	yes
overcast 	hot 	normal 	weak 	yes
rainy 		mild 	high 	strong 	no
stdout
outlook
	overcast
		yes
	rainy
		wind
		strong
			no
		weak
			yes
	sunny
		humidity
		high
			no
		normal
			yes