fork download
  1. // ランダムフォレスト回帰(random forest regression)
  2. // なお、データセットはRに付属しているorangeを使用しました。
  3.  
  #include <assert.h>

  #include <algorithm>
  #include <cmath>
  #include <cstdio>
  #include <cstdlib>
  #include <iostream>
  #include <utility>
  #include <vector>
  11.  
  12. using namespace std;
  13. #define SZ(a) ((int)(a).size())
  14.  
  15.  
  16. // 乱数は、xorを使ってますが、メルセンヌツイスターの方がよいかも知れません。
  // Xorshift-style pseudo-random number generator with four 32-bit state
  // words. The state always starts from the same constants, so the sequence
  // is identical on every run.
  class RandXor
  {
  public:
      RandXor()
      {
          init();
      }

      // Puts the generator back into its fixed starting state.
      void init()
      {
          x = 123456789;
          y = 362436069;
          z = 521288629;
          w = 88675123;
      }

      // Advances the state by one step and returns the freshly produced word.
      inline unsigned int random()
      {
          const unsigned int mixed = x ^ (x << 11);
          x = y;
          y = z;
          z = w;
          w = (w ^ (w >> 19)) ^ (mixed ^ (mixed >> 8));
          return w;
      }

      // Shuffles `a` in place: walking from the last slot down to slot 1,
      // each slot is swapped with a randomly chosen slot at or before it.
      inline void randomShuffle(std::vector<int> &a)
      {
          int last = static_cast<int>(a.size()) - 1;
          while (last > 0)
          {
              const unsigned int pick = random() % (last + 1);
              std::swap(a[last], a[pick]);
              --last;
          }
      }

  private:
      unsigned int x;
      unsigned int y;
      unsigned int z;
      unsigned int w;
  };

  // Generator shared by the whole file. To support multi-threading, give each
  // tree its own generator object and use a different seed for each.
  static RandXor randxor;
  53.  
  54.  
  using FeatureType = double; // value type of the explanatory variables x0, x1, ...
  using AnswerType = double;  // value type of the target variable y

  // One node of a decision tree. Nodes refer to their children by node ID
  // rather than by pointer.
  struct TreeNode {
      bool leaf;             // true for a leaf (a node without children)
      int level;             // depth of the node; the root is at 0
      int featureID;         // which explanatory variable to split on (the 0,1,2 in x0,x1,x2)
      FeatureType value;     // threshold used for the split
      AnswerType answer;     // mean of the target variable y inside this node's region
      std::vector<int> bags; // IDs of the data items that fall inside this node's region
      int left;              // node ID of the left child
      int right;             // node ID of the right child

      // Starts as a non-leaf with no data; -1 marks level, feature and
      // children as not yet assigned.
      TreeNode()
          : leaf(false),
            level(-1),
            featureID(-1),
            value(0),
            answer(0),
            bags(),
            left(-1),
            right(-1)
      {
      }
  };
  78.  
  79. class DecisionTree
  80. {
  81. public:
  82. DecisionTree() { }
  83.  
  84. // 学習。訓練データをいれて、決定木を作成する。
  85. // features 説明変数x0,x1,x2...
  86. // answers 目的変数y
  87. // minNodeSize ノード内
  88. // maxLevel ノードの深さの最大値
  89. // numRandomFeatures 領域を分けるときに試す説明変数(グラフでは軸)の数
  90. // numRandomPositions 領域を分けるときに試すデータ(グラフでは点)の数
  91. DecisionTree(const vector <vector <FeatureType> >& features,
  92. const vector <AnswerType>& answers,
  93. int minNodeSize,
  94. int maxLevel,
  95. int numRandomFeatures,
  96. int numRandomPositions)
  97. {
  98. const int numData = SZ(features);
  99. const int numFeatures = SZ(features[0]);
  100. assert(numData==SZ(answers));
  101. assert(numData>1);
  102.  
  103. TreeNode root; // ルートのノード
  104. root.level = 0;
  105.  
  106. root.bags.resize(numData);
  107. for (int i = 0; i < numData; i++)
  108. {
  109. // ここで、同じIDが選ばれる可能性があるが、問題なし。
  110. root.bags[i] = randxor.random()%numData;
  111. }
  112. m_nodes.emplace_back(root);
  113.  
  114. int curNode = 0;
  115. // m_nodesに子ノードがどんどん追加されていく幅優先探索
  116. while (curNode < SZ(m_nodes))
  117. {
  118. TreeNode &node = m_nodes[curNode];
  119.  
  120. // 現在のノードに入っている目的変数が、すべて同じかどうかを調べる
  121. // (その場合は、ノードを分ける必要がなくなる)
  122. bool equal = true; // すべて同じならtrue
  123. for (int i=1;i<SZ(node.bags);i++)
  124. {
  125. if (answers[node.bags[i]] != answers[node.bags[i - 1]])
  126. {
  127. equal = false;
  128. break;
  129. }
  130. }
  131.  
  132. // 葉になる条件のチェック
  133. if (equal || SZ(node.bags) <= minNodeSize || node.level >= maxLevel)
  134. {
  135. // 葉にして子ノードは増やさない。
  136. setLeaf( node, curNode, answers );
  137. continue;
  138. }
  139.  
  140. // どこで分けるのがベストかを調べる
  141. int bestFeatureID = -1;
  142. int bestLeft=0, bestRight=0;
  143. FeatureType bestValue = 0;
  144. double bestMSE = 1e99; // 平均との2乗の差の和
  145.  
  146. for(int i=0;i<numRandomFeatures;i++)
  147. {
  148. // x0,x1,x2...の、どの軸で分けるかを決める
  149. const int featureID = randxor.random()%numFeatures;
  150. for(int j=0;j<numRandomPositions;j++) // どの位置で分けるか
  151. {
  152. const int positionID = randxor.random()%SZ(node.bags);
  153. const FeatureType splitValue = features[node.bags[positionID]][featureID];
  154. double sumLeft = 0;
  155. double sumRight = 0;
  156. int totalLeft = 0; // splitValue未満の個数
  157. int totalRight = 0; // splitValue以上の個数
  158. for (int p : node.bags)
  159. {
  160. if (features[p][featureID] < splitValue)
  161. {
  162. sumLeft += answers[p];
  163. totalLeft++;
  164. } else {
  165. sumRight += answers[p];
  166. totalRight++;
  167. }
  168. }
  169.  
  170. // nodeBagのデータが"未満"か"以上"のどちらかに全部偏ってるので
  171. // 分け方として意味がないので、すぐやめる。
  172. if (totalLeft == 0 || totalRight == 0)
  173. continue;
  174.  
  175. // 平均との差を求める(回帰用)
  176. double meanLeft = totalLeft == 0 ? 0 : sumLeft / totalLeft;
  177. double meanRight = totalRight == 0 ? 0 : sumRight / totalRight;
  178.  
  179. double mse = 0;
  180. for (int p : node.bags)
  181. {
  182. if (features[p][featureID] < splitValue)
  183. {
  184. mse += (answers[p] - meanLeft) * (answers[p] - meanLeft);
  185. }
  186. else
  187. {
  188. mse += (answers[p] - meanRight) * (answers[p] - meanRight);
  189. }
  190. }
  191.  
  192. if (mse < bestMSE)
  193. {
  194. bestMSE = mse;
  195. bestValue = splitValue;
  196. bestFeatureID = featureID;
  197. bestLeft = totalLeft;
  198. bestRight = totalRight;
  199. }
  200. }
  201. }
  202.  
  203. // 左か右にどちらかに偏るような分け方しかできなかった場合は、葉にする
  204. // (すべての分け方を試すわけではないので、こういうことは起こりえます)
  205. if (bestLeft == 0 || bestRight == 0)
  206. {
  207. setLeaf( node, curNode, answers );
  208. continue;
  209. }
  210.  
  211. // うまく分けれたので、新しい子ノードを2つ追加する
  212. TreeNode left;
  213. TreeNode right;
  214.  
  215. left.level = right.level = node.level + 1;
  216. node.featureID = bestFeatureID;
  217. node.value = bestValue;
  218. node.left = SZ(m_nodes);
  219. node.right = SZ(m_nodes) + 1;
  220.  
  221. left.bags.resize(bestLeft);
  222. right.bags.resize(bestRight);
  223. for (int p : node.bags)
  224. {
  225. if (features[p][node.featureID] < node.value)
  226. {
  227. left.bags[--bestLeft] = p;
  228. }
  229. else
  230. {
  231. right.bags[--bestRight] = p;
  232. }
  233. }
  234.  
  235. m_nodes.emplace_back(left);
  236. m_nodes.emplace_back(right);
  237. curNode++;
  238. }
  239. }
  240.  
  241. // 予測
  242. // features テスト用の説明変数x0,x1,x2のセット
  243. // 返り値 目的変数yの予測値
  244. AnswerType estimate(const vector <FeatureType>& features) const
  245. {
  246. // ルートからたどるだけ
  247. const TreeNode *pNode = &m_nodes[0];
  248. while (true)
  249. {
  250. if (pNode->leaf)
  251. {
  252. break;
  253. }
  254.  
  255. const int nextNodeID = features[pNode->featureID] < pNode->value ? pNode->left : pNode->right;
  256. pNode = &m_nodes[nextNodeID];
  257. }
  258.  
  259. return pNode->answer;
  260. }
  261.  
  262. private:
  263.  
  264. // nodeを葉にして、curNodeを次のノードへ進める
  265. void setLeaf( TreeNode& node, int& curNode, const vector<AnswerType>& answers ) const
  266. {
  267. node.leaf = true;
  268.  
  269. // 回帰の場合は、目的変数yの平均を求める
  270. for (int p : node.bags)
  271. {
  272. node.answer += answers[p];
  273. }
  274.  
  275. assert(SZ(node.bags) > 0);
  276. if (SZ(node.bags))
  277. {
  278. node.answer /= SZ(node.bags);
  279. }
  280. curNode++;
  281. }
  282.  
  283. vector < TreeNode > m_nodes; // 決定木のノードたち。m_nodes[0]がルート
  284. };
  285.  
  286. class RandomForest {
  287. public:
  288. RandomForest()
  289. {
  290. clear();
  291. }
  292.  
  293. void clear()
  294. {
  295. m_trees.clear();
  296. }
  297.  
  298. // 訓練
  299. // 繰り返し呼ぶことで木を追加することもできる。
  300. // features 説明変数x0,x1,x2...のセット
  301. // answers 目的変数yのセット
  302. // treesNo      追加する木の数
  303. // minNodeSize ノード内
  304.  
  305. void train(const vector <vector <FeatureType> >& features,
  306. const vector <AnswerType>& answers,
  307. int treesNo,
  308. int minNodeSize)
  309. {
  310. for(int i=0;i<treesNo;i++)
  311. {
  312. m_trees.emplace_back(DecisionTree(features, answers, minNodeSize, 16, 2, 5));
  313. }
  314. }
  315.  
  316.  
  317. // 回帰の予測
  318. // features テスト用の説明変数x0,x1,x2のセット
  319. // 返り値 目的変数yの予測値
  320. AnswerType estimateRegression(vector <FeatureType> &features)
  321. {
  322. if (SZ(m_trees) == 0)
  323. {
  324. return 0.0;
  325. }
  326.  
  327. // すべての木から得られた結果の平均をとるだけ
  328. double sum = 0;
  329. for(int i=0;i<SZ(m_trees);i++)
  330. {
  331. sum += m_trees[i].estimate(features);
  332. }
  333. return sum / SZ(m_trees);
  334. }
  335.  
  336. private:
  337. vector < DecisionTree > m_trees; // たくさんの決定木
  338. };
  339.  
  340. int main()
  341. {
  342. int numAll; // 全データ数
  343. int numTrainings; // 訓練データ数
  344. int numTests; // テストデータ数
  345. int numFeatures; // 説明変数の数
  346.  
  347. // y = f(x0,x1,x2,...)
  348. // x0,x1,x2は説明変数です。コード上ではfeatureと命名してます。
  349. // yは目的変数です。コード上ではanswerという命名をしてます。
  350.  
  351.  
  352. cin >> numAll >> numTrainings >> numTests >> numFeatures;
  353. assert(numTrainings+numTests<=numAll);
  354.  
  355. // 全データ
  356. vector < vector <FeatureType> > allFeatures(numAll, vector <FeatureType> (numFeatures));
  357. vector < AnswerType > allAnswers(numAll);
  358.  
  359. for(int i = 0 ; i < numAll; ++i)
  360. {
  361. for (int k = 0; k < numFeatures; ++k)
  362. {
  363. cin >> allFeatures[i][k];
  364. }
  365. cin >> allAnswers[i];
  366. }
  367.  
  368. // シャッフル用
  369. vector < int > shuffleTable;
  370. for (int i = 0; i < numTrainings+numTests; ++i)
  371. {
  372. shuffleTable.emplace_back(i);
  373. }
  374. randxor.randomShuffle(shuffleTable);
  375.  
  376. // 訓練データ
  377. vector < vector <FeatureType> > trainingFeatures(numTrainings, vector <FeatureType>(numFeatures));
  378. vector < AnswerType > trainingAnswers(numTrainings);
  379. for (int i = 0; i < numTrainings; ++i)
  380. {
  381. trainingFeatures[i] = allFeatures[shuffleTable[i]];
  382. trainingAnswers[i] = allAnswers[shuffleTable[i]];
  383. }
  384.  
  385. // テストデータ
  386. vector < vector <FeatureType> > testFeatures(numTests, vector <FeatureType>(numFeatures));
  387. vector < AnswerType > testAnswers(numTests);
  388. for (int i = 0; i < numTests; ++i)
  389. {
  390. testFeatures[i] = allFeatures[shuffleTable[numTrainings+i]];
  391. testAnswers[i] = allAnswers[shuffleTable[numTrainings+i]];
  392. }
  393.  
  394. // ランダムフォレストを使って予測
  395. RandomForest* rf = new RandomForest();
  396.  
  397. // 木を徐々に増やしていく
  398. int numTrees = 0;
  399. for (int k = 0; k < 20; ++k)
  400. {
  401. // 学習
  402. const int numAdditionalTrees = 1;
  403. rf->train(trainingFeatures, trainingAnswers, numAdditionalTrees, 1);
  404. numTrees += numAdditionalTrees;
  405.  
  406. // 予測と結果表示
  407. cout << "-----" << endl;
  408. cout << "numTrees=" << numTrees << endl;
  409. double totalError = 0.0;
  410. for (int i = 0; i < numTests; ++i)
  411. {
  412. const double myAnswer = rf->estimateRegression(testFeatures[i]);
  413. const double diff = myAnswer-testAnswers[i];
  414. totalError += abs(diff);
  415. cout << " myAnswer=" << myAnswer << " testAnswer=" << testAnswers[i] << " diff=" << diff << endl;
  416. }
  417. cout << "totalError=" << totalError << endl;
  418. }
  419.  
  420. delete rf;
  421.  
  422. return 0;
  423. }
  424.  
  425.  
Success #stdin #stdout 0s 4340KB
stdin
35 30 5 2
1  118  30
1  484  58
1  664  87
1 1004 115
1 1231 120
1 1372 142
1 1582 145
2  118  33
2  484  69
2  664 111
2 1004 156
2 1231 172
2 1372 203
2 1582 203
3  118  30
3  484  51
3  664  75
3 1004 108
3 1231 115
3 1372 139
3 1582 140
4  118  32
4  484  62
4  664 112
4 1004 167
4 1231 179
4 1372 209
4 1582 214
5  118  30
5  484  49
5  664  81
5 1004 125
5 1231 142
5 1372 174
5 1582 177
stdout
-----
numTrees=1
 myAnswer=32 testAnswer=30 diff=2
 myAnswer=111.8 testAnswer=75 diff=36.8
 myAnswer=115 testAnswer=167 diff=-52
 myAnswer=139 testAnswer=140 diff=-1
 myAnswer=142 testAnswer=145 diff=-3
totalError=94.8
-----
numTrees=2
 myAnswer=32 testAnswer=30 diff=2
 myAnswer=111.4 testAnswer=75 diff=36.4
 myAnswer=113.5 testAnswer=167 diff=-53.5
 myAnswer=158 testAnswer=140 diff=18
 myAnswer=159.5 testAnswer=145 diff=14.5
totalError=124.4
-----
numTrees=3
 myAnswer=32 testAnswer=30 diff=2
 myAnswer=111.6 testAnswer=75 diff=36.6
 myAnswer=111.667 testAnswer=167 diff=-55.3333
 myAnswer=176.667 testAnswer=140 diff=36.6667
 myAnswer=177.667 testAnswer=145 diff=32.6667
totalError=163.267
-----
numTrees=4
 myAnswer=31.5 testAnswer=30 diff=1.5
 myAnswer=111.7 testAnswer=75 diff=36.7
 myAnswer=122.75 testAnswer=167 diff=-44.25
 myAnswer=167.25 testAnswer=140 diff=27.25
 myAnswer=168.75 testAnswer=145 diff=23.75
totalError=133.45
-----
numTrees=5
 myAnswer=35 testAnswer=30 diff=5
 myAnswer=105.76 testAnswer=75 diff=30.76
 myAnswer=114.6 testAnswer=167 diff=-52.4
 myAnswer=161.6 testAnswer=140 diff=21.6
 myAnswer=159 testAnswer=145 diff=14
totalError=123.76
-----
numTrees=6
 myAnswer=34.5 testAnswer=30 diff=4.5
 myAnswer=106.133 testAnswer=75 diff=31.1333
 myAnswer=113.5 testAnswer=167 diff=-53.5
 myAnswer=167.25 testAnswer=140 diff=27.25
 myAnswer=152.5 testAnswer=145 diff=7.5
totalError=123.883
-----
numTrees=7
 myAnswer=49.9524 testAnswer=30 diff=19.9524
 myAnswer=98.2571 testAnswer=75 diff=23.2571
 myAnswer=117.667 testAnswer=167 diff=-49.3333
 myAnswer=163.214 testAnswer=140 diff=23.2143
 myAnswer=140.657 testAnswer=145 diff=-4.34286
totalError=120.1
-----
numTrees=8
 myAnswer=49.8333 testAnswer=30 diff=19.8333
 myAnswer=89.725 testAnswer=75 diff=14.725
 myAnswer=119.812 testAnswer=167 diff=-47.1875
 myAnswer=169.562 testAnswer=140 diff=29.5625
 myAnswer=140.825 testAnswer=145 diff=-4.175
totalError=115.483
-----
numTrees=9
 myAnswer=47.8519 testAnswer=30 diff=17.8519
 myAnswer=92.2 testAnswer=75 diff=17.2
 myAnswer=123.833 testAnswer=167 diff=-43.1667
 myAnswer=170.389 testAnswer=140 diff=30.3889
 myAnswer=144.844 testAnswer=145 diff=-0.155556
totalError=108.763
-----
numTrees=10
 myAnswer=46.1667 testAnswer=30 diff=16.1667
 myAnswer=92.355 testAnswer=75 diff=17.355
 myAnswer=127.45 testAnswer=167 diff=-39.55
 myAnswer=174.75 testAnswer=140 diff=34.75
 myAnswer=150.66 testAnswer=145 diff=5.66
totalError=113.482
-----
numTrees=11
 myAnswer=44.8788 testAnswer=30 diff=14.8788
 myAnswer=94.1409 testAnswer=75 diff=19.1409
 myAnswer=132.136 testAnswer=167 diff=-34.8636
 myAnswer=177.318 testAnswer=140 diff=37.3182
 myAnswer=155.418 testAnswer=145 diff=10.4182
totalError=116.62
-----
numTrees=12
 myAnswer=43.8056 testAnswer=30 diff=13.8056
 myAnswer=95.6292 testAnswer=75 diff=20.6292
 myAnswer=130.458 testAnswer=167 diff=-36.5417
 myAnswer=172.125 testAnswer=140 diff=32.125
 myAnswer=153.078 testAnswer=145 diff=8.07778
totalError=111.179
-----
numTrees=13
 myAnswer=42.7897 testAnswer=30 diff=12.7897
 myAnswer=96.0423 testAnswer=75 diff=21.0423
 myAnswer=134.192 testAnswer=167 diff=-32.8077
 myAnswer=169.577 testAnswer=140 diff=29.5769
 myAnswer=152.226 testAnswer=145 diff=7.22564
totalError=103.442
-----
numTrees=14
 myAnswer=41.8762 testAnswer=30 diff=11.8762
 myAnswer=96.7 testAnswer=75 diff=21.7
 myAnswer=132.929 testAnswer=167 diff=-34.0714
 myAnswer=172.357 testAnswer=140 diff=32.3571
 myAnswer=150.448 testAnswer=145 diff=5.44762
totalError=105.452
-----
numTrees=15
 myAnswer=41.2178 testAnswer=30 diff=11.2178
 myAnswer=97.72 testAnswer=75 diff=22.72
 myAnswer=131.533 testAnswer=167 diff=-35.4667
 myAnswer=170.133 testAnswer=140 diff=30.1333
 myAnswer=146.644 testAnswer=145 diff=1.64444
totalError=101.182
-----
numTrees=16
 myAnswer=40.5167 testAnswer=30 diff=10.5167
 myAnswer=98.0917 testAnswer=75 diff=23.0917
 myAnswer=130.208 testAnswer=167 diff=-36.7917
 myAnswer=172.875 testAnswer=140 diff=32.875
 myAnswer=146.354 testAnswer=145 diff=1.35417
totalError=104.629
-----
numTrees=17
 myAnswer=39.898 testAnswer=30 diff=9.89804
 myAnswer=98.851 testAnswer=75 diff=23.851
 myAnswer=131.055 testAnswer=167 diff=-35.9451
 myAnswer=173.627 testAnswer=140 diff=33.6275
 myAnswer=148.667 testAnswer=145 diff=3.66667
totalError=106.988
-----
numTrees=18
 myAnswer=39.4593 testAnswer=30 diff=9.45926
 myAnswer=97.8593 testAnswer=75 diff=22.8593
 myAnswer=130.719 testAnswer=167 diff=-36.2815
 myAnswer=170.37 testAnswer=140 diff=30.3704
 myAnswer=147.074 testAnswer=145 diff=2.07407
totalError=101.044
-----
numTrees=19
 myAnswer=41.6456 testAnswer=30 diff=11.6456
 myAnswer=94.2877 testAnswer=75 diff=19.2877
 myAnswer=129.523 testAnswer=167 diff=-37.4772
 myAnswer=171.246 testAnswer=140 diff=31.2456
 myAnswer=146.807 testAnswer=145 diff=1.80702
totalError=101.463
-----
numTrees=20
 myAnswer=45.8133 testAnswer=30 diff=15.8133
 myAnswer=94.7233 testAnswer=75 diff=19.7233
 myAnswer=128.647 testAnswer=167 diff=-38.3533
 myAnswer=169.633 testAnswer=140 diff=29.6333
 myAnswer=149.617 testAnswer=145 diff=4.61667
totalError=108.14