fork(7) download
  1. // ランダムフォレスト回帰(random forest regression)
  2. // なお、データセットはRに付属しているorangeを使用しました。
  3.  
  #include <assert.h>

  #include <algorithm>
  #include <cmath>
  #include <cstdio>
  #include <cstdlib>
  #include <iostream>
  #include <random>
  #include <utility>
  #include <vector>
  10.  
  using namespace std;

  // Returns the number of elements in container `a` as a signed int, so that
  // it can be compared against int loop counters without signed/unsigned
  // mismatches. Implemented as a function template instead of a macro so the
  // argument is type-checked and evaluated exactly once.
  template <class Container>
  inline int SZ(const Container& a)
  {
      return static_cast<int>(a.size());
  }
  13.  
  14.  
  // Random numbers come from an xor-shift generator with four unsigned state
  // words; a Mersenne Twister might be a better choice.
  class RandXor
  {
  public:
      // Starts the generator from the same fixed state that init() restores.
      RandXor()
          : m_x(123456789), m_y(362436069), m_z(521288629), m_w(88675123)
      {
      }

      // Resets the four state words to their fixed starting values, so the
      // generator replays the same sequence from the beginning.
      void init()
      {
          m_x = 123456789;
          m_y = 362436069;
          m_z = 521288629;
          m_w = 88675123;
      }

      // Advances the state by one step and returns the newly produced value.
      inline unsigned int random()
      {
          const unsigned int mixed = m_x ^ (m_x << 11);
          // Rotate the state words: the oldest one is consumed above.
          m_x = m_y;
          m_y = m_z;
          m_z = m_w;
          m_w = (m_w ^ (m_w >> 19)) ^ (mixed ^ (mixed >> 8));
          return m_w;
      }

  private:
      unsigned int m_x;
      unsigned int m_y;
      unsigned int m_z;
      unsigned int m_w;
  };

  // Shared generator used by the whole program. To support multithreading,
  // give each tree its own generator object with a different seed.
  static RandXor randxor;
  45.  
  46.  
  using FeatureType = double;   // type of one explanatory-variable value
  using AnswerType = double;    // type of the target variable

  // One node of a decision tree. Children are referred to by node ID rather
  // than by pointer; -1 means "not set".
  struct TreeNode {
      bool leaf;              // true if this is a leaf (has no children)
      int level;              // depth of the node; the root node is 0
      int featureID;          // explanatory-variable ID: the 0,1,2 of x0, x1, x2...
      FeatureType value;      // value at which the node is split
      AnswerType answer;      // mean of the target variable y inside the node (region)
      std::vector<int> bags;  // IDs of the data points inside the node (region)
      int left;               // node ID of the left child
      int right;              // node ID of the right child

      // Creates an internal (non-leaf) node with no level, feature or
      // children assigned yet and an empty bag.
      TreeNode()
          : leaf(false),
            level(-1),
            featureID(-1),
            value(0),
            answer(0),
            left(-1),
            right(-1)
      {
      }
  };
  70.  
  71. class DecisionTree
  72. {
  73. public:
  74. DecisionTree() { }
  75.  
  76. // 学習。訓練データをいれて、決定木を作成する。
  77. // features 説明変数x0,x1,x2...
  78. // answers 目的変数y
  79. // minNodeSize ノード内
  80. // maxLevel ノードの深さの最大値
  81. // numRandomFeatures 領域を分けるときに試す説明変数(グラフでは軸)の数
  82. // numRandomPositions 領域を分けるときに試すデータ(グラフでは点)の数
  83. DecisionTree(const vector <vector <FeatureType> >& features,
  84. const vector <AnswerType>& answers,
  85. int minNodeSize,
  86. int maxLevel,
  87. int numRandomFeatures,
  88. int numRandomPositions)
  89. {
  90. const int numData = SZ(features);
  91. const int numFeatures = SZ(features[0]);
  92. assert(numData==SZ(answers));
  93. assert(numData>1);
  94.  
  95. TreeNode root; // ルートのノード
  96. root.level = 0;
  97.  
  98. root.bags.resize(numData);
  99. for (int i = 0; i < numData; i++)
  100. {
  101. // ここで、同じIDが選ばれる可能性があるが、問題なし。
  102. root.bags[i] = randxor.random()%numData;
  103. }
  104. m_nodes.emplace_back(root);
  105.  
  106. int curNode = 0;
  107. // m_nodesに子ノードがどんどん追加されていく幅優先探索
  108. while (curNode < SZ(m_nodes))
  109. {
  110. TreeNode &node = m_nodes[curNode];
  111.  
  112. // 現在のノードに入っている目的変数が、すべて同じかどうかを調べる
  113. // (その場合は、ノードを分ける必要がなくなる)
  114. bool equal = true; // すべて同じならtrue
  115. for (int i=1;i<SZ(node.bags);i++)
  116. {
  117. if (answers[node.bags[i]] != answers[node.bags[i - 1]])
  118. {
  119. equal = false;
  120. break;
  121. }
  122. }
  123.  
  124. // 葉になる条件のチェック
  125. if (equal || SZ(node.bags) <= minNodeSize || node.level >= maxLevel)
  126. {
  127. // 葉にして子ノードは増やさない。
  128. setLeaf( node, curNode, answers );
  129. continue;
  130. }
  131.  
  132. // どこで分けるのがベストかを調べる
  133. int bestFeatureID = -1;
  134. int bestLeft=0, bestRight=0;
  135. FeatureType bestValue = 0;
  136. double bestMSE = 1e99; // 平均との2乗の差の和
  137.  
  138. for(int i=0;i<numRandomFeatures;i++)
  139. {
  140. // x0,x1,x2...の、どの軸で分けるかを決める
  141. const int featureID = randxor.random()%numFeatures;
  142. for(int j=0;j<numRandomPositions;j++) // どの位置で分けるか
  143. {
  144. const int positionID = randxor.random()%SZ(node.bags);
  145. const FeatureType splitValue = features[node.bags[positionID]][featureID];
  146. double sumLeft = 0;
  147. double sumRight = 0;
  148. int totalLeft = 0; // splitValue未満の個数
  149. int totalRight = 0; // splitValue以上の個数
  150. for (int p : node.bags)
  151. {
  152. if (features[p][featureID] < splitValue)
  153. {
  154. sumLeft += answers[p];
  155. totalLeft++;
  156. } else {
  157. sumRight += answers[p];
  158. totalRight++;
  159. }
  160. }
  161.  
  162. // nodeBagのデータが"未満"か"以上"のどちらかに全部偏ってるので
  163. // 分け方として意味がないので、すぐやめる。
  164. if (totalLeft == 0 || totalRight == 0)
  165. continue;
  166.  
  167. // 平均との差を求める(回帰用)
  168. double meanLeft = totalLeft == 0 ? 0 : sumLeft / totalLeft;
  169. double meanRight = totalRight == 0 ? 0 : sumRight / totalRight;
  170.  
  171. double mse = 0;
  172. for (int p : node.bags)
  173. {
  174. if (features[p][featureID] < splitValue)
  175. {
  176. mse += (answers[p] - meanLeft) * (answers[p] - meanLeft);
  177. }
  178. else
  179. {
  180. mse += (answers[p] - meanRight) * (answers[p] - meanRight);
  181. }
  182. }
  183.  
  184. if (mse < bestMSE)
  185. {
  186. bestMSE = mse;
  187. bestValue = splitValue;
  188. bestFeatureID = featureID;
  189. bestLeft = totalLeft;
  190. bestRight = totalRight;
  191. }
  192. }
  193. }
  194.  
  195. // 左か右にどちらかに偏るような分け方しかできなかった場合は、葉にする
  196. // (すべての分け方を試すわけではないので、こういうことは起こりえます)
  197. if (bestLeft == 0 || bestRight == 0)
  198. {
  199. setLeaf( node, curNode, answers );
  200. continue;
  201. }
  202.  
  203. // うまく分けれたので、新しい子ノードを2つ追加する
  204. TreeNode left;
  205. TreeNode right;
  206.  
  207. left.level = right.level = node.level + 1;
  208. node.featureID = bestFeatureID;
  209. node.value = bestValue;
  210. node.left = SZ(m_nodes);
  211. node.right = SZ(m_nodes) + 1;
  212.  
  213. left.bags.resize(bestLeft);
  214. right.bags.resize(bestRight);
  215. for (int p : node.bags)
  216. {
  217. if (features[p][node.featureID] < node.value)
  218. {
  219. left.bags[--bestLeft] = p;
  220. }
  221. else
  222. {
  223. right.bags[--bestRight] = p;
  224. }
  225. }
  226.  
  227. m_nodes.emplace_back(left);
  228. m_nodes.emplace_back(right);
  229. curNode++;
  230. }
  231. }
  232.  
  233. // 予測
  234. // features テスト用の説明変数x0,x1,x2のセット
  235. // 返り値 目的変数yの予測値
  236. AnswerType estimate(const vector <FeatureType>& features) const
  237. {
  238. // ルートからたどるだけ
  239. const TreeNode *pNode = &m_nodes[0];
  240. while (true)
  241. {
  242. if (pNode->leaf)
  243. {
  244. break;
  245. }
  246.  
  247. const int nextNodeID = features[pNode->featureID] < pNode->value ? pNode->left : pNode->right;
  248. pNode = &m_nodes[nextNodeID];
  249. }
  250.  
  251. return pNode->answer;
  252. }
  253.  
  254. private:
  255.  
  256. // nodeを葉にして、curNodeを次のノードへ進める
  257. void setLeaf( TreeNode& node, int& curNode, const vector<AnswerType>& answers ) const
  258. {
  259. node.leaf = true;
  260.  
  261. // 回帰の場合は、目的変数yの平均を求める
  262. for (int p : node.bags)
  263. {
  264. node.answer += answers[p];
  265. }
  266.  
  267. assert(SZ(node.bags) > 0);
  268. if (SZ(node.bags))
  269. {
  270. node.answer /= SZ(node.bags);
  271. }
  272. curNode++;
  273. }
  274.  
  275. vector < TreeNode > m_nodes; // 決定木のノードたち。m_nodes[0]がルート
  276. };
  277.  
  278. class RandomForest {
  279. public:
  280. RandomForest()
  281. {
  282. clear();
  283. }
  284.  
  285. void clear()
  286. {
  287. m_trees.clear();
  288. }
  289.  
  290. // 訓練
  291. // 繰り返し呼ぶことで木を追加することもできる。
  292. // features 説明変数x0,x1,x2...のセット
  293. // answers 目的変数yのセット
  294. // treesNo      追加する木の数
  295. // minNodeSize ノード内
  296.  
  297. void train(const vector <vector <FeatureType> >& features,
  298. const vector <AnswerType>& answers,
  299. int treesNo,
  300. int minNodeSize)
  301. {
  302. for(int i=0;i<treesNo;i++)
  303. {
  304. m_trees.emplace_back(DecisionTree(features, answers, minNodeSize, 16, 2, 5));
  305. }
  306. }
  307.  
  308.  
  309. // 回帰の予測
  310. // features テスト用の説明変数x0,x1,x2のセット
  311. // 返り値 目的変数yの予測値
  312. AnswerType estimateRegression(vector <FeatureType> &features)
  313. {
  314. if (SZ(m_trees) == 0)
  315. {
  316. return 0.0;
  317. }
  318.  
  319. // すべての木から得られた結果の平均をとるだけ
  320. double sum = 0;
  321. for(int i=0;i<SZ(m_trees);i++)
  322. {
  323. sum += m_trees[i].estimate(features);
  324. }
  325. return sum / SZ(m_trees);
  326. }
  327.  
  328. private:
  329. vector < DecisionTree > m_trees; // たくさんの決定木
  330. };
  331.  
  332. int main()
  333. {
  334. int numAll; // 全データ数
  335. int numTrainings; // 訓練データ数
  336. int numTests; // テストデータ数
  337. int numFeatures; // 説明変数の数
  338.  
  339. // y = f(x0,x1,x2,...)
  340. // x0,x1,x2は説明変数です。コード上ではfeatureと命名してます。
  341. // yは目的変数です。コード上ではanswerという命名をしてます。
  342.  
  343.  
  344. cin >> numAll >> numTrainings >> numTests >> numFeatures;
  345. assert(numTrainings+numTests<=numAll);
  346.  
  347. // 全データ
  348. vector < vector <FeatureType> > allFeatures(numAll, vector <FeatureType> (numFeatures));
  349. vector < AnswerType > allAnswers(numAll);
  350.  
  351. for(int i = 0 ; i < numAll; ++i)
  352. {
  353. for (int k = 0; k < numFeatures; ++k)
  354. {
  355. cin >> allFeatures[i][k];
  356. }
  357. cin >> allAnswers[i];
  358. }
  359.  
  360. // シャッフル用
  361. vector < int > shuffleTable;
  362. for (int i = 0; i < numTrainings+numTests; ++i)
  363. {
  364. shuffleTable.emplace_back(i);
  365. }
  366. random_shuffle(shuffleTable.begin(), shuffleTable.end());
  367.  
  368. // 訓練データ
  369. vector < vector <FeatureType> > trainingFeatures(numTrainings, vector <FeatureType>(numFeatures));
  370. vector < AnswerType > trainingAnswers(numTrainings);
  371. for (int i = 0; i < numTrainings; ++i)
  372. {
  373. trainingFeatures[i] = allFeatures[shuffleTable[i]];
  374. trainingAnswers[i] = allAnswers[shuffleTable[i]];
  375. }
  376.  
  377. // テストデータ
  378. vector < vector <FeatureType> > testFeatures(numTests, vector <FeatureType>(numFeatures));
  379. vector < AnswerType > testAnswers(numTests);
  380. for (int i = 0; i < numTests; ++i)
  381. {
  382. testFeatures[i] = allFeatures[shuffleTable[numTrainings+i]];
  383. testAnswers[i] = allAnswers[shuffleTable[numTrainings+i]];
  384. }
  385.  
  386. // ランダムフォレストを使って予測
  387. RandomForest* rf = new RandomForest();
  388.  
  389. // 木を徐々に増やしていく
  390. int numTrees = 0;
  391. for (int k = 0; k < 20; ++k)
  392. {
  393. // 学習
  394. const int numAdditionalTrees = 1;
  395. rf->train(trainingFeatures, trainingAnswers, numAdditionalTrees, 1);
  396. numTrees += numAdditionalTrees;
  397.  
  398. // 予測と結果表示
  399. cout << "-----" << endl;
  400. cout << "numTrees=" << numTrees << endl;
  401. double totalError = 0.0;
  402. for (int i = 0; i < numTests; ++i)
  403. {
  404. const double myAnswer = rf->estimateRegression(testFeatures[i]);
  405. const double diff = myAnswer-testAnswers[i];
  406. totalError += abs(diff);
  407. cout << " myAnswer=" << myAnswer << " testAnswer=" << testAnswers[i] << " diff=" << diff << endl;
  408. }
  409. cout << "totalError=" << totalError << endl;
  410. }
  411.  
  412. delete rf;
  413.  
  414. return 0;
  415. }
  416.  
  417.  
Success #stdin #stdout 0s 3488KB
stdin
35 30 5 2
1  118  30
1  484  58
1  664  87
1 1004 115
1 1231 120
1 1372 142
1 1582 145
2  118  33
2  484  69
2  664 111
2 1004 156
2 1231 172
2 1372 203
2 1582 203
3  118  30
3  484  51
3  664  75
3 1004 108
3 1231 115
3 1372 139
3 1582 140
4  118  32
4  484  62
4  664 112
4 1004 167
4 1231 179
4 1372 209
4 1582 214
5  118  30
5  484  49
5  664  81
5 1004 125
5 1231 142
5 1372 174
5 1582 177
stdout
-----
numTrees=1
 myAnswer=125 testAnswer=81 diff=44
 myAnswer=156 testAnswer=111 diff=45
 myAnswer=203 testAnswer=203 diff=0
 myAnswer=174 testAnswer=177 diff=-3
 myAnswer=103.5 testAnswer=115 diff=-11.5
totalError=103.5
-----
numTrees=2
 myAnswer=100 testAnswer=81 diff=19
 myAnswer=115.5 testAnswer=111 diff=4.5
 myAnswer=203 testAnswer=203 diff=0
 myAnswer=194 testAnswer=177 diff=17
 myAnswer=111.75 testAnswer=115 diff=-3.25
totalError=43.75
-----
numTrees=3
 myAnswer=76.6667 testAnswer=81 diff=-4.33333
 myAnswer=97.5556 testAnswer=111 diff=-13.4444
 myAnswer=187.333 testAnswer=203 diff=-15.6667
 myAnswer=187.333 testAnswer=177 diff=10.3333
 myAnswer=121.833 testAnswer=115 diff=6.83333
totalError=50.6111
-----
numTrees=4
 myAnswer=71.125 testAnswer=81 diff=-9.875
 myAnswer=86.7917 testAnswer=111 diff=-24.2083
 myAnswer=176.75 testAnswer=203 diff=-26.25
 myAnswer=184 testAnswer=177 diff=7
 myAnswer=121.375 testAnswer=115 diff=6.375
totalError=73.7083
-----
numTrees=5
 myAnswer=74.3 testAnswer=81 diff=-6.7
 myAnswer=86.8333 testAnswer=111 diff=-24.1667
 myAnswer=170.4 testAnswer=203 diff=-32.6
 myAnswer=190 testAnswer=177 diff=13
 myAnswer=130.5 testAnswer=115 diff=15.5
totalError=91.9667
-----
numTrees=6
 myAnswer=68.5 testAnswer=81 diff=-12.5
 myAnswer=83.8611 testAnswer=111 diff=-27.1389
 myAnswer=170.667 testAnswer=203 diff=-32.3333
 myAnswer=187.333 testAnswer=177 diff=10.3333
 myAnswer=134.75 testAnswer=115 diff=19.75
totalError=102.056
-----
numTrees=7
 myAnswer=70.7143 testAnswer=81 diff=-10.2857
 myAnswer=83.881 testAnswer=111 diff=-27.119
 myAnswer=169.81 testAnswer=203 diff=-33.1905
 myAnswer=184.095 testAnswer=177 diff=7.09524
 myAnswer=137.786 testAnswer=115 diff=22.7857
totalError=100.476
-----
numTrees=8
 myAnswer=72.75 testAnswer=81 diff=-8.25
 myAnswer=84.2708 testAnswer=111 diff=-26.7292
 myAnswer=168.083 testAnswer=203 diff=-34.9167
 myAnswer=182.833 testAnswer=177 diff=5.83333
 myAnswer=140.062 testAnswer=115 diff=25.0625
totalError=100.792
-----
numTrees=9
 myAnswer=74.0278 testAnswer=81 diff=-6.97222
 myAnswer=92.2407 testAnswer=111 diff=-18.7593
 myAnswer=171.269 testAnswer=203 diff=-31.7315
 myAnswer=184.38 testAnswer=177 diff=7.37963
 myAnswer=141.833 testAnswer=115 diff=26.8333
totalError=91.6759
-----
numTrees=10
 myAnswer=77.825 testAnswer=81 diff=-3.175
 myAnswer=98.6167 testAnswer=111 diff=-12.3833
 myAnswer=168.517 testAnswer=203 diff=-34.4833
 myAnswer=187.342 testAnswer=177 diff=10.3417
 myAnswer=133.45 testAnswer=115 diff=18.45
totalError=78.8333
-----
numTrees=11
 myAnswer=80.9318 testAnswer=81 diff=-0.0681818
 myAnswer=96.4697 testAnswer=111 diff=-14.5303
 myAnswer=166.379 testAnswer=203 diff=-36.6212
 myAnswer=183.038 testAnswer=177 diff=6.03788
 myAnswer=135.5 testAnswer=115 diff=20.5
totalError=77.7576
-----
numTrees=12
 myAnswer=82.4792 testAnswer=81 diff=1.47917
 myAnswer=96.7222 testAnswer=111 diff=-14.2778
 myAnswer=164.181 testAnswer=203 diff=-38.8194
 myAnswer=185.618 testAnswer=177 diff=8.61806
 myAnswer=134.208 testAnswer=115 diff=19.2083
totalError=82.4028
-----
numTrees=13
 myAnswer=81.4423 testAnswer=81 diff=0.442308
 myAnswer=94.5897 testAnswer=111 diff=-16.4103
 myAnswer=162.705 testAnswer=203 diff=-40.2949
 myAnswer=184.724 testAnswer=177 diff=7.72436
 myAnswer=134.808 testAnswer=115 diff=19.8077
totalError=84.6795
-----
numTrees=14
 myAnswer=77.8393 testAnswer=81 diff=-3.16071
 myAnswer=92.7619 testAnswer=111 diff=-18.2381
 myAnswer=163.369 testAnswer=203 diff=-39.631
 myAnswer=186.815 testAnswer=177 diff=9.81548
 myAnswer=135.583 testAnswer=115 diff=20.5833
totalError=91.4286
-----
numTrees=15
 myAnswer=75.9167 testAnswer=81 diff=-5.08333
 myAnswer=92.3778 testAnswer=111 diff=-18.6222
 myAnswer=162.144 testAnswer=203 diff=-40.8556
 myAnswer=188.628 testAnswer=177 diff=11.6278
 myAnswer=136.944 testAnswer=115 diff=21.9444
totalError=98.1333
-----
numTrees=16
 myAnswer=74.2344 testAnswer=81 diff=-6.76562
 myAnswer=96.3542 testAnswer=111 diff=-14.6458
 myAnswer=164.698 testAnswer=203 diff=-38.3021
 myAnswer=184.651 testAnswer=177 diff=7.65104
 myAnswer=133.823 testAnswer=115 diff=18.8229
totalError=86.1875
-----
numTrees=17
 myAnswer=76.4559 testAnswer=81 diff=-4.54412
 myAnswer=95.8039 testAnswer=111 diff=-15.1961
 myAnswer=163.539 testAnswer=203 diff=-39.4608
 myAnswer=186.377 testAnswer=177 diff=9.37745
 myAnswer=131.069 testAnswer=115 diff=16.0686
totalError=84.6471
-----
numTrees=18
 myAnswer=78.4306 testAnswer=81 diff=-2.56944
 myAnswer=99.1481 testAnswer=111 diff=-11.8519
 myAnswer=162.231 testAnswer=203 diff=-40.7685
 myAnswer=185.69 testAnswer=177 diff=8.68981
 myAnswer=132.454 testAnswer=115 diff=17.4537
totalError=81.3333
-----
numTrees=19
 myAnswer=76.8816 testAnswer=81 diff=-4.11842
 myAnswer=97.2456 testAnswer=111 diff=-13.7544
 myAnswer=162.746 testAnswer=203 diff=-40.2544
 myAnswer=183.39 testAnswer=177 diff=6.39035
 myAnswer=132.456 testAnswer=115 diff=17.4561
totalError=81.9737
-----
numTrees=20
 myAnswer=74.5375 testAnswer=81 diff=-6.4625
 myAnswer=95.8333 testAnswer=111 diff=-15.1667
 myAnswer=163.725 testAnswer=203 diff=-39.275
 myAnswer=184.921 testAnswer=177 diff=7.92083
 myAnswer=128.733 testAnswer=115 diff=13.7333
totalError=82.5583