fork(10) download
  1. // ランダムフォレスト多クラス分類(random forest classification)
  2. // なお、データセットはRに付属しているirisを使用しました。
  3.  
  4. #include <cstdio>
  5. #include <cstdlib>
  6. #include <iostream>
  7. #include <vector>
  8. #include <algorithm>
  9. #include <assert.h>
  10.  
  11. using namespace std;
  12. #define SZ(a) ((int)(a).size())
  13.  
  14.  
  15. // 乱数は、xorを使ってますが、メルセンヌツイスターの方がよいかも知れません。
  16. class RandXor
  17. {
  18. public:
  19. RandXor()
  20. {
  21. init();
  22. }
  23.  
  24. void init()
  25. {
  26. x=123456789;
  27. y=362436069;
  28. z=521288629;
  29. w= 88675123;
  30. }
  31.  
  32. inline unsigned int random()
  33. {
  34. unsigned int t;
  35. t=(x^(x<<11));x=y;y=z;z=w; return( w=(w^(w>>19))^(t^(t>>8)) );
  36. }
  37. private:
  38. unsigned int x;
  39. unsigned int y;
  40. unsigned int z;
  41. unsigned int w;
  42. };
  43.  
  44. static RandXor randxor; // マルチスレッド対応にするなら、木ごとに乱数用オブジェクトを用意して、シードを変えましょう。
  45.  
  46.  
  47. typedef double FeatureType;
  48. typedef int AnswerType;
  49. static const int NUM_CLASSES = 3; // 分類のときのクラス数
  50. enum
  51. {
  52. LEFT,
  53. RIGHT,
  54. NUM_LR,
  55. };
  56.  
  57. struct TreeNode {
  58. bool leaf; // 葉(=子がない)ならtrue
  59. int level; // ノードの深さ。ルートノードは0
  60. int featureID; // 説明変数ID。x0, x1, x2... の0,1,2の部分
  61. FeatureType value; // 分割する値
  62. AnswerType answer; // ノード内(=領域内)の目的変数yの平均値
  63. vector <int> bags; // ノード内(=領域内)に含まれるデータのID
  64. int left; // 左側の子のノードID
  65. int right; // 右側の子のノードID
  66.  
  67. TreeNode() {
  68. leaf = false;
  69. level = -1;
  70. featureID = -1;
  71. value = 0;
  72. answer = 0;
  73. left = -1;
  74. right = -1;
  75. }
  76. };
  77.  
  78. class DecisionTree
  79. {
  80. public:
  81. DecisionTree() { }
  82.  
  83. // 学習。訓練データをいれて、決定木を作成する。
  84. // features 説明変数x0,x1,x2...
  85. // answers 目的変数y
  86. // minNodeSize ノード内
  87. // maxLevel ノードの深さの最大値
  88. // numRandomFeatures 領域を分けるときに試す説明変数(グラフでは軸)の数
  89. // numRandomPositions 領域を分けるときに試すデータ(グラフでは点)の数
  90. DecisionTree(const vector <vector <FeatureType> >& features,
  91. const vector <AnswerType>& answers,
  92. int minNodeSize,
  93. int maxLevel,
  94. int numRandomFeatures,
  95. int numRandomPositions)
  96. {
  97. const int numData = SZ(features);
  98. const int numFeatures = SZ(features[0]);
  99. assert(numData==SZ(answers));
  100. assert(numData>1);
  101.  
  102. TreeNode root; // ルートのノード
  103. root.level = 0;
  104.  
  105. root.bags.resize(numData);
  106. for (int i = 0; i < numData; i++)
  107. {
  108. // ここで、同じIDが選ばれる可能性があるが、問題なし。
  109. root.bags[i] = randxor.random()%numData;
  110. }
  111. m_nodes.emplace_back(root);
  112.  
  113. int curNode = 0;
  114. // m_nodesに子ノードがどんどん追加されていく幅優先探索
  115. while (curNode < SZ(m_nodes))
  116. {
  117. TreeNode &node = m_nodes[curNode];
  118.  
  119. // 現在のノードに入っている目的変数が、すべて同じかどうかを調べる
  120. // (その場合は、ノードを分ける必要がなくなる)
  121. bool equal = true; // すべて同じならtrue
  122. for (int i=1;i<SZ(node.bags);i++)
  123. {
  124. if (answers[node.bags[i]] != answers[node.bags[i - 1]])
  125. {
  126. equal = false;
  127. break;
  128. }
  129. }
  130.  
  131. // 葉になる条件のチェック
  132. if (equal || SZ(node.bags) <= minNodeSize || node.level >= maxLevel)
  133. {
  134. // 葉にして子ノードは増やさない。
  135. setLeaf( node, curNode, answers );
  136. continue;
  137. }
  138.  
  139. // どこで分けるのがベストかを調べる
  140. int bestFeatureID = -1;
  141. int bestLeft=0, bestRight=0;
  142. FeatureType bestValue = 0;
  143. double bestGini = 1e99; // ジニ係数
  144.  
  145. for(int i=0;i<numRandomFeatures;i++)
  146. {
  147. // x0,x1,x2...の、どの軸で分けるかを決める
  148. const int featureID = randxor.random()%numFeatures;
  149. for(int j=0;j<numRandomPositions;j++) // どの位置で分けるか
  150. {
  151. const int positionID = randxor.random()%SZ(node.bags);
  152. const FeatureType splitValue = features[node.bags[positionID]][featureID];
  153.  
  154. int total[NUM_LR] = {}; // splitValue未満, splitValue以上の個数
  155. int freq[NUM_LR][NUM_CLASSES]={}; // [どっち側か][クラス] = 個数
  156.  
  157. for (int p : node.bags)
  158. {
  159. int lr = RIGHT;
  160. if (features[p][featureID] < splitValue)
  161. {
  162. lr = LEFT;
  163. }
  164. total[lr]++;
  165. freq[lr][answers[p]]++;
  166. }
  167.  
  168. // nodeBagのデータが"未満"か"以上"のどちらかに全部偏ってるので
  169. // 分け方として意味がないので、すぐやめる。
  170. if (total[LEFT] == 0 || total[RIGHT] == 0)
  171. continue;
  172.  
  173. // ジニ係数を求める(分類用)
  174. double gini = 0;
  175.  
  176. for(int lr = 0; lr < NUM_LR; ++lr)
  177. {
  178. double tmpGini = 1.0;
  179. for(int c=0; c<NUM_CLASSES;++c)
  180. {
  181. double ratio = (double)freq[lr][c]/total[lr];
  182. tmpGini -= ratio*ratio;
  183. }
  184. gini += tmpGini * total[lr]/SZ(node.bags);
  185. }
  186. // BEGIN CUT HERE
  187. // cout << " curNode=" << curNode << " gini=" << gini << endl;
  188. // END CUT HERE
  189.  
  190.  
  191. if (gini < bestGini)
  192. {
  193. bestGini = gini;
  194. bestValue = splitValue;
  195. bestFeatureID = featureID;
  196. bestLeft = total[LEFT];
  197. bestRight = total[RIGHT];
  198. }
  199. }
  200. }
  201.  
  202. // 左か右にどちらかに偏るような分け方しかできなかった場合は、葉にする
  203. // (すべての分け方を試すわけではないので、こういうことは起こりえます)
  204. if (bestLeft == 0 || bestRight == 0)
  205. {
  206. setLeaf( node, curNode, answers );
  207. continue;
  208. }
  209.  
  210. // うまく分けれたので、新しい子ノードを2つ追加する
  211. TreeNode left;
  212. TreeNode right;
  213.  
  214. left.level = right.level = node.level + 1;
  215. node.featureID = bestFeatureID;
  216. node.value = bestValue;
  217. node.left = SZ(m_nodes);
  218. node.right = SZ(m_nodes) + 1;
  219.  
  220. left.bags.resize(bestLeft);
  221. right.bags.resize(bestRight);
  222. for (int p : node.bags)
  223. {
  224. if (features[p][node.featureID] < node.value)
  225. {
  226. left.bags[--bestLeft] = p;
  227. }
  228. else
  229. {
  230. right.bags[--bestRight] = p;
  231. }
  232. }
  233.  
  234. m_nodes.emplace_back(left);
  235. m_nodes.emplace_back(right);
  236. curNode++;
  237. }
  238. }
  239.  
  240. // 予測
  241. // features テスト用の説明変数x0,x1,x2のセット
  242. // 返り値 目的変数yの予測値
  243. AnswerType estimate(const vector <FeatureType>& features) const
  244. {
  245. // ルートからたどるだけ
  246. const TreeNode *pNode = &m_nodes[0];
  247. while (true)
  248. {
  249. if (pNode->leaf)
  250. {
  251. break;
  252. }
  253.  
  254. const int nextNodeID = features[pNode->featureID] < pNode->value ? pNode->left : pNode->right;
  255. pNode = &m_nodes[nextNodeID];
  256. }
  257.  
  258. return pNode->answer;
  259. }
  260.  
  261. private:
  262.  
  263. // nodeを葉にして、curNodeを次のノードへ進める
  264. void setLeaf( TreeNode& node, int& curNode, const vector<AnswerType>& answers ) const
  265. {
  266. node.leaf = true;
  267.  
  268. // 分類の場合は、多数決(ここで平均を使う手法もあるよう)
  269. int freq[NUM_CLASSES]={};
  270. for (int p : node.bags)
  271. {
  272. freq[answers[p]]++;
  273. }
  274. int bestFreq = -1;
  275. int bestC = -1;
  276. for (int c = 0; c < NUM_CLASSES; ++c)
  277. {
  278. if(freq[c]>bestFreq)
  279. {
  280. bestFreq = freq[c];
  281. bestC = c;
  282. }
  283. }
  284.  
  285. node.answer = bestC;
  286.  
  287. curNode++;
  288. }
  289.  
  290. vector < TreeNode > m_nodes; // 決定木のノードたち。m_nodes[0]がルート
  291. };
  292.  
  293. class RandomForest {
  294. public:
  295. RandomForest()
  296. {
  297. clear();
  298. }
  299.  
  300. void clear()
  301. {
  302. m_trees.clear();
  303. }
  304.  
  305. // 訓練
  306. // 繰り返し呼ぶことで木を追加することもできる。
  307. // features 説明変数x0,x1,x2...のセット
  308. // answers 目的変数yのセット
  309. // treesNo      追加する木の数
  310. // minNodeSize ノード内
  311.  
  312. void train(const vector <vector <FeatureType> >& features,
  313. const vector <AnswerType>& answers,
  314. int treesNo,
  315. int minNodeSize)
  316. {
  317. for(int i=0;i<treesNo;i++)
  318. {
  319. m_trees.emplace_back(DecisionTree(features, answers, minNodeSize, 16, 2, 5));
  320. }
  321. }
  322.  
  323.  
  324. // 分類の予測
  325. // features テスト用の説明変数x0,x1,x2のセット
  326. // 返り値 目的変数yの予測値
  327. AnswerType estimateClassification(vector <FeatureType> &features)
  328. {
  329. if (SZ(m_trees) == 0)
  330. {
  331. return 0;
  332. }
  333.  
  334. // 多数決
  335. int freq[NUM_CLASSES]={};
  336. for(int i=0;i<SZ(m_trees);i++)
  337. {
  338. freq[m_trees[i].estimate(features)]++;
  339. }
  340.  
  341. int bestFreq = -1;
  342. int bestC = -1;
  343. for (int c = 0; c < NUM_CLASSES; ++c)
  344. {
  345. if(freq[c]>bestFreq)
  346. {
  347. bestFreq = freq[c];
  348. bestC = c;
  349. }
  350. }
  351.  
  352. return bestC;
  353. }
  354.  
  355. private:
  356. vector < DecisionTree > m_trees; // たくさんの決定木
  357. };
  358.  
  359. int main()
  360. {
  361. int numAll; // 全データ数
  362. int numTrainings; // 訓練データ数
  363. int numTests; // テストデータ数
  364. int numFeatures; // 説明変数の数
  365.  
  366. // y = f(x0,x1,x2,...)
  367. // x0,x1,x2は説明変数です。コード上ではfeatureと命名してます。
  368. // yは目的変数です。コード上ではanswerという命名をしてます。
  369.  
  370.  
  371. cin >> numAll >> numTrainings >> numTests >> numFeatures;
  372. assert(numTrainings+numTests<=numAll);
  373.  
  374. // 全データ
  375. vector < vector <FeatureType> > allFeatures(numAll, vector <FeatureType> (numFeatures));
  376. vector < AnswerType > allAnswers(numAll);
  377.  
  378. for(int i = 0 ; i < numAll; ++i)
  379. {
  380. for (int k = 0; k < numFeatures; ++k)
  381. {
  382. cin >> allFeatures[i][k];
  383. }
  384. cin >> allAnswers[i];
  385. assert(allAnswers[i]>=0);
  386. assert(allAnswers[i]<NUM_CLASSES);
  387. }
  388.  
  389. // シャッフル用
  390. vector < int > shuffleTable;
  391. for (int i = 0; i < numTrainings+numTests; ++i)
  392. {
  393. shuffleTable.emplace_back(i);
  394. }
  395. random_shuffle(shuffleTable.begin(), shuffleTable.end());
  396.  
  397. // 訓練データ
  398. vector < vector <FeatureType> > trainingFeatures(numTrainings, vector <FeatureType>(numFeatures));
  399. vector < AnswerType > trainingAnswers(numTrainings);
  400. for (int i = 0; i < numTrainings; ++i)
  401. {
  402. trainingFeatures[i] = allFeatures[shuffleTable[i]];
  403. trainingAnswers[i] = allAnswers[shuffleTable[i]];
  404. }
  405.  
  406. // テストデータ
  407. vector < vector <FeatureType> > testFeatures(numTests, vector <FeatureType>(numFeatures));
  408. vector < AnswerType > testAnswers(numTests);
  409. for (int i = 0; i < numTests; ++i)
  410. {
  411. testFeatures[i] = allFeatures[shuffleTable[numTrainings+i]];
  412. testAnswers[i] = allAnswers[shuffleTable[numTrainings+i]];
  413. }
  414.  
  415. // ランダムフォレストを使って予測
  416. RandomForest* rf = new RandomForest();
  417.  
  418. // 木を徐々に増やしていく
  419. int numTrees = 0;
  420. for (int k = 0; k < 20; ++k)
  421. {
  422. // 学習
  423. const int numAdditionalTrees = 1;
  424. rf->train(trainingFeatures, trainingAnswers, numAdditionalTrees, 1);
  425. numTrees += numAdditionalTrees;
  426.  
  427. // 予測と結果表示
  428. cout << "-----" << endl;
  429. cout << "numTrees=" << numTrees << endl;
  430. double totalError = 0.0;
  431. for (int i = 0; i < numTests; ++i)
  432. {
  433. const AnswerType myAnswer = rf->estimateClassification(testFeatures[i]);
  434. int diff = 0;
  435. if(myAnswer!=testAnswers[i])
  436. {
  437. // cout << "Failure! i=" << i << " myAnswer=" << myAnswer << " testAnswer=" << testAnswers[i] << endl;
  438. diff = 1;
  439. }
  440. totalError += diff;
  441. }
  442. cout << "totalError=" << totalError << endl;
  443. }
  444.  
  445. delete rf;
  446.  
  447. return 0;
  448. }
  449.  
  450.  
Success #stdin #stdout 0s 3444KB
stdin
150 25 125 4
5.1 3.5 1.4 0.2 0
4.9 3   1.4 0.2 0
4.7 3.2 1.3 0.2 0
4.6 3.1 1.5 0.2 0
5   3.6 1.4 0.2 0
5.4 3.9 1.7 0.4 0
4.6 3.4 1.4 0.3 0
5   3.4 1.5 0.2 0
4.4 2.9 1.4 0.2 0
4.9 3.1 1.5 0.1 0
5.4 3.7 1.5 0.2 0
4.8 3.4 1.6 0.2 0
4.8 3   1.4 0.1 0
4.3 3   1.1 0.1 0
5.8 4   1.2 0.2 0
5.7 4.4 1.5 0.4 0
5.4 3.9 1.3 0.4 0
5.1 3.5 1.4 0.3 0
5.7 3.8 1.7 0.3 0
5.1 3.8 1.5 0.3 0
5.4 3.4 1.7 0.2 0
5.1 3.7 1.5 0.4 0
4.6 3.6 1   0.2 0
5.1 3.3 1.7 0.5 0
4.8 3.4 1.9 0.2 0
5   3   1.6 0.2 0
5   3.4 1.6 0.4 0
5.2 3.5 1.5 0.2 0
5.2 3.4 1.4 0.2 0
4.7 3.2 1.6 0.2 0
4.8 3.1 1.6 0.2 0
5.4 3.4 1.5 0.4 0
5.2 4.1 1.5 0.1 0
5.5 4.2 1.4 0.2 0
4.9 3.1 1.5 0.2 0
5   3.2 1.2 0.2 0
5.5 3.5 1.3 0.2 0
4.9 3.6 1.4 0.1 0
4.4 3   1.3 0.2 0
5.1 3.4 1.5 0.2 0
5   3.5 1.3 0.3 0
4.5 2.3 1.3 0.3 0
4.4 3.2 1.3 0.2 0
5   3.5 1.6 0.6 0
5.1 3.8 1.9 0.4 0
4.8 3   1.4 0.3 0
5.1 3.8 1.6 0.2 0
4.6 3.2 1.4 0.2 0
5.3 3.7 1.5 0.2 0
5   3.3 1.4 0.2 0
7   3.2 4.7 1.4 1
6.4 3.2 4.5 1.5 1
6.9 3.1 4.9 1.5 1
5.5 2.3 4   1.3 1
6.5 2.8 4.6 1.5 1
5.7 2.8 4.5 1.3 1
6.3 3.3 4.7 1.6 1
4.9 2.4 3.3 1   1
6.6 2.9 4.6 1.3 1
5.2 2.7 3.9 1.4 1
5   2   3.5 1   1
5.9 3   4.2 1.5 1
6   2.2 4   1   1
6.1 2.9 4.7 1.4 1
5.6 2.9 3.6 1.3 1
6.7 3.1 4.4 1.4 1
5.6 3   4.5 1.5 1
5.8 2.7 4.1 1   1
6.2 2.2 4.5 1.5 1
5.6 2.5 3.9 1.1 1
5.9 3.2 4.8 1.8 1
6.1 2.8 4   1.3 1
6.3 2.5 4.9 1.5 1
6.1 2.8 4.7 1.2 1
6.4 2.9 4.3 1.3 1
6.6 3   4.4 1.4 1
6.8 2.8 4.8 1.4 1
6.7 3   5   1.7 1
6   2.9 4.5 1.5 1
5.7 2.6 3.5 1   1
5.5 2.4 3.8 1.1 1
5.5 2.4 3.7 1   1
5.8 2.7 3.9 1.2 1
6   2.7 5.1 1.6 1
5.4 3   4.5 1.5 1
6   3.4 4.5 1.6 1
6.7 3.1 4.7 1.5 1
6.3 2.3 4.4 1.3 1
5.6 3   4.1 1.3 1
5.5 2.5 4   1.3 1
5.5 2.6 4.4 1.2 1
6.1 3   4.6 1.4 1
5.8 2.6 4   1.2 1
5   2.3 3.3 1   1
5.6 2.7 4.2 1.3 1
5.7 3   4.2 1.2 1
5.7 2.9 4.2 1.3 1
6.2 2.9 4.3 1.3 1
5.1 2.5 3   1.1 1
5.7 2.8 4.1 1.3 1
6.3 3.3 6   2.5 2
5.8 2.7 5.1 1.9 2
7.1 3   5.9 2.1 2
6.3 2.9 5.6 1.8 2
6.5 3   5.8 2.2 2
7.6 3   6.6 2.1 2
4.9 2.5 4.5 1.7 2
7.3 2.9 6.3 1.8 2
6.7 2.5 5.8 1.8 2
7.2 3.6 6.1 2.5 2
6.5 3.2 5.1 2   2
6.4 2.7 5.3 1.9 2
6.8 3   5.5 2.1 2
5.7 2.5 5   2   2
5.8 2.8 5.1 2.4 2
6.4 3.2 5.3 2.3 2
6.5 3   5.5 1.8 2
7.7 3.8 6.7 2.2 2
7.7 2.6 6.9 2.3 2
6   2.2 5   1.5 2
6.9 3.2 5.7 2.3 2
5.6 2.8 4.9 2   2
7.7 2.8 6.7 2   2
6.3 2.7 4.9 1.8 2
6.7 3.3 5.7 2.1 2
7.2 3.2 6   1.8 2
6.2 2.8 4.8 1.8 2
6.1 3   4.9 1.8 2
6.4 2.8 5.6 2.1 2
7.2 3   5.8 1.6 2
7.4 2.8 6.1 1.9 2
7.9 3.8 6.4 2   2
6.4 2.8 5.6 2.2 2
6.3 2.8 5.1 1.5 2
6.1 2.6 5.6 1.4 2
7.7 3   6.1 2.3 2
6.3 3.4 5.6 2.4 2
6.4 3.1 5.5 1.8 2
6   3   4.8 1.8 2
6.9 3.1 5.4 2.1 2
6.7 3.1 5.6 2.4 2
6.9 3.1 5.1 2.3 2
5.8 2.7 5.1 1.9 2
6.8 3.2 5.9 2.3 2
6.7 3.3 5.7 2.5 2
6.7 3   5.2 2.3 2
6.3 2.5 5   1.9 2
6.5 3   5.2 2   2
6.2 3.4 5.4 2.3 2
5.9 3   5.1 1.8 2
stdout
-----
numTrees=1
totalError=48
-----
numTrees=2
totalError=39
-----
numTrees=3
totalError=15
-----
numTrees=4
totalError=14
-----
numTrees=5
totalError=8
-----
numTrees=6
totalError=12
-----
numTrees=7
totalError=10
-----
numTrees=8
totalError=11
-----
numTrees=9
totalError=8
-----
numTrees=10
totalError=8
-----
numTrees=11
totalError=8
-----
numTrees=12
totalError=8
-----
numTrees=13
totalError=8
-----
numTrees=14
totalError=8
-----
numTrees=15
totalError=8
-----
numTrees=16
totalError=8
-----
numTrees=17
totalError=8
-----
numTrees=18
totalError=8
-----
numTrees=19
totalError=8
-----
numTrees=20
totalError=8