fork download
  1. // ランダムフォレスト回帰(random forest regression)by 診断人
  2. // データはhttp://c...content-available-to-author-only...e.com/mathematics/brase/understandable_statistics/7e/students/datasets/mlr/frames/frame.htmlのDenver Neighborhoodsを使用しています。
  3.  
  4. #include <cstdio>
  5. #include <cstdlib>
  6. #include <iostream>
  7. #include <vector>
  8. #include <algorithm>
  9. #include <assert.h>
  10.  
  11. #define INCMSE (1) // 重要度INCMSEを使用するか
  12. #define INCNODEPURITY (1) // 重要度INCNODEPURITYを使用するか
  13.  
  14. using namespace std;
  15. #define SZ(a) ((int)(a).size())
  16.  
  17.  
  18. // 乱数は、xorを使ってますが、メルセンヌツイスターの方がよいかも知れません。
  19. class RandXor
  20. {
  21. public:
  22. RandXor()
  23. {
  24. init();
  25. }
  26.  
  27. void init()
  28. {
  29. x=123456789;
  30. y=362436069;
  31. z=521288629;
  32. w= 88675123;
  33. }
  34.  
  35. inline unsigned int random()
  36. {
  37. unsigned int t;
  38. t=(x^(x<<11));x=y;y=z;z=w; return( w=(w^(w>>19))^(t^(t>>8)) );
  39. }
  40. private:
  41. unsigned int x;
  42. unsigned int y;
  43. unsigned int z;
  44. unsigned int w;
  45. };
  46.  
  47. static RandXor randxor; // マルチスレッド対応にするなら、木ごとに乱数用オブジェクトを用意して、シードを変えましょう。
  48.  
  49.  
  50. typedef double FeatureType;
  51. typedef double AnswerType;
  52.  
  53. struct TreeNode {
  54. bool leaf; // 葉(=子がない)ならtrue
  55. int level; // ノードの深さ。ルートノードは0
  56. int featureID; // 説明変数ID。x0, x1, x2... の0,1,2の部分
  57. FeatureType value; // 分割する値
  58. AnswerType answer; // ノード内(=領域内)の目的変数yの平均値
  59. vector <int> bags; // ノード内(=領域内)に含まれるデータのID
  60. int left; // 左側の子のノードID
  61. int right; // 右側の子のノードID
  62. #if INCNODEPURITY
  63. double nodePurity; // ノードのPurity。平均2乗誤差が入る
  64. #endif // INCNODEPURITY
  65.  
  66.  
  67. TreeNode() {
  68. leaf = false;
  69. level = -1;
  70. featureID = -1;
  71. value = 0;
  72. answer = 0;
  73. left = -1;
  74. right = -1;
  75. #if INCNODEPURITY
  76. nodePurity = 0;
  77. #endif // INCNODEPURITY
  78. }
  79. };
  80.  
  81. class DecisionTree
  82. {
  83. public:
  84. DecisionTree() { }
  85.  
  86. // 学習。訓練データをいれて、決定木を作成する。
  87. // features 説明変数(特徴量)x0,x1,x2...
  88. // answers 目的変数y
  89. // nodeSize ノード内に入るデータの数。これを下回れば計算は打ち切られ葉になる。Rの実装では、回帰のときは5、分類のときは1が使われている。
  90. // maxLevel ノードの深さの最大値
  91. // numRandomFeatures 領域を分けるときに試す説明変数(グラフでは軸)の数。Rの実装ではmtryという引数で、回帰のときは(説明変数の数)/3、分類のときは(説明変数の数)^0.5が使われている。
  92. // numRandomPositions 領域を分けるときに試すデータ(グラフでは点)の数
  93. // rootBagSizeRatio 全訓練データの中から、どれだけの割合をルートに入れるか(ブートストラップサンプリング)
  94. DecisionTree(const vector <vector <FeatureType> >& features,
  95. const vector <AnswerType>& answers,
  96. int nodeSize,
  97. int maxLevel,
  98. int numRandomFeatures,
  99. int numRandomPositions,
  100. float rootBagSizeRatio)
  101. {
  102. const int numData = SZ(features); // データの数
  103. const int numFeatures = SZ(features[0]); // 説明変数の数
  104. assert(numData == SZ(answers));
  105. assert(numData>1);
  106.  
  107. #if INCNODEPURITY
  108. m_incNodePurity.clear();
  109. m_incNodePurity.resize(numFeatures);
  110. #endif // INCNODEPURITY
  111.  
  112. TreeNode root; // ルートのノード
  113. root.level = 0;
  114.  
  115. const int rootBagSize = static_cast<int>(numData * rootBagSizeRatio); // ルートに入るデータの数
  116. root.bags.resize(rootBagSize);
  117.  
  118. #if INCMSE
  119. vector <int> freq(numData); // データ番号別の、ルートに選ばれた個数
  120. #endif // INCMSE
  121.  
  122. // ブートストラップサンプリング
  123. for (int i = 0; i < rootBagSize; i++)
  124. {
  125. // ここで、同じIDが選ばれる可能性があるが、問題なし。
  126. int row = randxor.random() % numData;
  127. root.bags[i] = row;
  128. #if INCMSE
  129. freq[row]++;
  130. #endif // INCMSE
  131. }
  132.  
  133. #if INCMSE
  134. // 選ばれなかった
  135. m_oob.clear();
  136. for (int i = 0; i < SZ(freq); ++i)
  137. {
  138. if(freq[i]==0)
  139. {
  140. m_oob.emplace_back(i);
  141. }
  142. }
  143. #endif // INCMSE
  144.  
  145. #if INCNODEPURITY
  146. double sumAll = 0; // ルートのyの総和
  147. for (int p : root.bags)
  148. {
  149. sumAll += answers[p];
  150. }
  151.  
  152. const double meanAll = sumAll / SZ(root.bags); // ルートのyの平均
  153. // ルートの平均2乗誤差
  154. root.nodePurity = 0;
  155. for (int p : root.bags)
  156. {
  157. root.nodePurity += lossFunction(answers[p], meanAll);
  158. }
  159. #endif // INCNODEPURITY
  160.  
  161. // 決定木をつくる。m_nodesに子ノードがどんどん追加されていく幅優先探索。
  162. m_nodes.emplace_back(root);
  163. int curNode = 0; // 現在の決定木のノード番号
  164. while (curNode < SZ(m_nodes))
  165. {
  166. TreeNode &node = m_nodes[curNode];
  167.  
  168. // 現在のノードに入っている目的変数が、すべて同じかどうかを調べる
  169. // (その場合は、ノードを分ける必要がなくなる)
  170. bool equal = true; // すべて同じならtrue
  171. for (int i = 1; i<SZ(node.bags); i++)
  172. {
  173. if (answers[node.bags[i]] != answers[node.bags[i - 1]])
  174. {
  175. equal = false;
  176. break;
  177. }
  178. }
  179.  
  180. // 葉になる条件のチェック
  181. if (equal || SZ(node.bags) <= nodeSize || node.level >= maxLevel)
  182. {
  183. // 葉にして子ノードは増やさない。
  184. setLeaf(node, curNode, answers);
  185. continue;
  186. }
  187.  
  188. // どこで分けるのがベストかを調べる
  189. int bestFeatureID = -1;
  190. int bestLeft = 0, bestRight = 0;
  191. FeatureType bestValue = 0;
  192. double bestMSE = 1e99; // 平均2乗誤差(平均との差の2乗の和)
  193. #if INCNODEPURITY
  194. double bestMSELeft = 1e99; // splitValue未満の平均2乗誤差
  195. double bestMSERight = 1e99; // splitValue以上の平均2乗誤差
  196. #endif // INCNODEPURITY
  197.  
  198. for (int i = 0; i<numRandomFeatures; i++)
  199. {
  200. // x0,x1,x2...の、どの軸で分けるかを決める
  201. int featureID = randxor.random() % numFeatures;
  202.  
  203.  
  204. for (int j = 0; j<numRandomPositions; j++) // どの位置で分けるか
  205. {
  206. const int positionID = randxor.random() % SZ(node.bags);
  207. const FeatureType splitValue = features[node.bags[positionID]][featureID];
  208. double sumLeft = 0; // splitValue未満のyの総和
  209. double sumRight = 0; // splitValue以上のyの総和
  210. int totalLeft = 0; // splitValue未満の個数
  211. int totalRight = 0; // splitValue以上の個数
  212. for (int p : node.bags)
  213. {
  214. if (features[p][featureID] < splitValue)
  215. {
  216. sumLeft += answers[p];
  217. totalLeft++;
  218. }
  219. else
  220. {
  221. sumRight += answers[p];
  222. totalRight++;
  223. }
  224. }
  225.  
  226. // nodeBagのデータが"未満"か"以上"のどちらかに全部偏ってるので
  227. // 分け方として意味がないので、すぐやめる。
  228. if (totalLeft == 0 || totalRight == 0)
  229. continue;
  230.  
  231. // 平均との差を使って、平均二乗誤差(MSE)を求める
  232. double meanLeft = totalLeft == 0 ? 0 : sumLeft / totalLeft;
  233. double meanRight = totalRight == 0 ? 0 : sumRight / totalRight;
  234.  
  235. double mseLeft = 0; // 平均二乗誤差(MSE)
  236. double mseRight = 0; // 平均二乗誤差(MSE)
  237. for (int p : node.bags)
  238. {
  239. if (features[p][featureID] < splitValue)
  240. {
  241. mseLeft += lossFunction(answers[p], meanLeft);
  242. }
  243. else
  244. {
  245. mseRight += lossFunction(answers[p], meanRight);
  246. }
  247. }
  248.  
  249. if (mseLeft + mseRight < bestMSE)
  250. {
  251. bestMSE = mseLeft + mseRight;
  252. #if INCNODEPURITY
  253. bestMSELeft = mseLeft;
  254. bestMSERight = mseRight;
  255. #endif // INCNODEPURITY
  256. bestValue = splitValue;
  257. bestFeatureID = featureID;
  258. bestLeft = totalLeft;
  259. bestRight = totalRight;
  260. }
  261. }
  262. }
  263.  
  264. // 左か右にどちらかに偏るような分け方しかできなかった場合は、葉にする
  265. // (すべての分け方を試すわけではないので、こういうことは起こりえます)
  266. if (bestLeft == 0 || bestRight == 0)
  267. {
  268. setLeaf(node, curNode, answers);
  269. continue;
  270. }
  271.  
  272. // うまく分けれたので、新しい子ノードを2つ追加する
  273. TreeNode left;
  274. TreeNode right;
  275.  
  276. left.level = right.level = node.level + 1;
  277. node.featureID = bestFeatureID;
  278. node.value = bestValue;
  279. node.left = SZ(m_nodes);
  280. node.right = SZ(m_nodes) + 1;
  281.  
  282. left.bags.resize(bestLeft);
  283. right.bags.resize(bestRight);
  284. for (int p : node.bags)
  285. {
  286. if (features[p][node.featureID] < node.value)
  287. {
  288. left.bags[--bestLeft] = p;
  289. }
  290. else
  291. {
  292. right.bags[--bestRight] = p;
  293. }
  294. }
  295.  
  296. #if INCNODEPURITY
  297. left.nodePurity = bestMSELeft;
  298. right.nodePurity = bestMSERight;
  299. m_incNodePurity[bestFeatureID] += node.nodePurity - left.nodePurity - right.nodePurity;
  300. #endif // INCNODEPURITY
  301. m_nodes.emplace_back(left);
  302. m_nodes.emplace_back(right);
  303. curNode++;
  304. }
  305. }
  306.  
  307. // 平均との差の2乗を求める
  308. // y 値
  309. // mean 平均
  310. // 返り値 平均との差の2乗
  311. double lossFunction(double y, double mean) const
  312. {
  313. return (y - mean)*(y - mean);
  314. }
  315.  
  316. // 予測
  317. // features テスト用の説明変数x0,x1,x2のセット
  318. // 返り値 目的変数yの予測値
  319. AnswerType estimate(const vector <FeatureType>& features) const
  320. {
  321. // ルートからたどるだけ
  322. const TreeNode *pNode = &m_nodes[0];
  323. while (true)
  324. {
  325. if (pNode->leaf)
  326. {
  327. break;
  328. }
  329.  
  330. const int nextNodeID = features[pNode->featureID] < pNode->value ? pNode->left : pNode->right;
  331. pNode = &m_nodes[nextNodeID];
  332. }
  333.  
  334. return pNode->answer;
  335. }
  336.  
  337. // bagsをクリアする。ランダムフォレストで一番メモリを使うのはbagsで、
  338. // 訓練してしまえばいらないので、メモリ不足のときはクリアしたほうがよい。
  339. void clearBags()
  340. {
  341. for (auto& node : m_nodes)
  342. {
  343. node.bags.clear();
  344. node.bags.shrink_to_fit();
  345. }
  346. }
  347.  
  348. // OOBデータ
  349. const vector <int>& getOOB() const
  350. {
  351. return m_oob;
  352. }
  353.  
  354. #if INCNODEPURITY
  355. const vector < double >& getIncNodePurity() const
  356. {
  357. return m_incNodePurity;
  358. }
  359. #endif // INCNODEPURITY
  360.  
  361. private:
  362.  
  363. // nodeを葉にして、curNodeを次のノードへ進める
  364. void setLeaf(TreeNode& node, int& curNode, const vector<AnswerType>& answers) const
  365. {
  366. node.leaf = true;
  367.  
  368. // 回帰の場合は、目的変数yの平均を求める
  369. for (int p : node.bags)
  370. {
  371. node.answer += answers[p];
  372. }
  373.  
  374. assert(SZ(node.bags) > 0);
  375. if (SZ(node.bags))
  376. {
  377. node.answer /= SZ(node.bags);
  378. }
  379. curNode++;
  380. }
  381.  
  382. vector < TreeNode > m_nodes; // 決定木のノードたち。m_nodes[0]がルート
  383. vector < int > m_oob; // OOBデータ(OOB Examples。訓練データから外れたもの)
  384. #if INCNODEPURITY
  385. vector < double > m_incNodePurity;
  386. #endif // INCNODEPURITY
  387. };
  388.  
  389. // ランダムフォレスト回帰
  390. class RandomForest {
  391. public:
  392. RandomForest()
  393. {
  394. clear();
  395. }
  396.  
  397. // 全部の決定木をクリア
  398. void clear()
  399. {
  400. m_trees.clear();
  401. }
  402.  
  403. // 学習。訓練データをいれて、決定木を作成する。
  404. // 引数は、DecisionTreeにそのまま渡しているだけなので、そちらのコメントを参照してください。
  405. // numAdditionalTrees 追加する木の数
  406. void train(const vector <vector <FeatureType> >& features,
  407. const vector <AnswerType>& answers,
  408. int nodeSize,
  409. int maxLevel,
  410. int numRandomFeatures,
  411. int numRandomPositions,
  412. float rootBagSizeRatio,
  413. int numAdditionalTrees)
  414. {
  415. for (int i = 0; i < numAdditionalTrees; i++)
  416. {
  417. m_trees.emplace_back(DecisionTree(features, answers, nodeSize, maxLevel, numRandomFeatures, numRandomPositions, rootBagSizeRatio));
  418. }
  419. }
  420.  
  421.  
  422. // 予測
  423. // features テスト用の説明変数x0,x1,x2のセット
  424. // 返り値 目的変数yの予測値
  425. AnswerType estimateRegression(const vector <FeatureType> &testFeatures) const
  426. {
  427. if (SZ(m_trees) == 0)
  428. {
  429. return 0.0;
  430. }
  431.  
  432. // すべての木から得られた結果の平均をとるだけ
  433. FeatureType sum = 0;
  434. for (int i = 0; i < SZ(m_trees); i++)
  435. {
  436. sum += m_trees[i].estimate(testFeatures);
  437. }
  438. return sum / SZ(m_trees);
  439. }
  440.  
  441. // 決定木のbagsをクリア(メモリ節約用)
  442. // treeID 決定木のID
  443. void clearBags(int treeID)
  444. {
  445. m_trees[treeID].clearBags();
  446. }
  447.  
  448.  
  449. #if INCMSE
  450.  
  451. // OOBエラーを求める
  452. // trainingFeatures 訓練データの説明変数(特徴量)x0,x1,x2...
  453. // trainingAnswers 訓練データの目的変数y
  454. // 返り値 OOBエラー
  455. double calculateOOBErr(const vector <vector <FeatureType> >& trainingFeatures, const vector < AnswerType >& trainingAnswers) const
  456. {
  457. const int numFeatures = SZ(trainingFeatures[0]);
  458. double averageOOBErr = 0.0;
  459.  
  460. for (int i = 0; i < SZ(m_trees); i++)
  461. {
  462. const DecisionTree& tree = m_trees[i];
  463. double ooberr = 0.0;
  464.  
  465. // OOBデータを使って、予測。2乗誤差の総和をooberrに入れる。
  466. const vector <int>& oob = tree.getOOB();
  467. for (int x = 0; x < SZ(oob); ++x)
  468. {
  469. const double myAnswer = tree.estimate(trainingFeatures[oob[x]]);
  470.  
  471. ooberr += (trainingAnswers[oob[x]] - myAnswer)*(trainingAnswers[oob[x]] - myAnswer);
  472. }
  473.  
  474. if(SZ(oob)>0)
  475. {
  476. ooberr /= SZ(oob);
  477. }
  478.  
  479. averageOOBErr += ooberr;
  480. }
  481.  
  482. averageOOBErr /= SZ(m_trees);
  483.  
  484. return averageOOBErr;
  485. }
  486.  
  487. // IncMSEを求める
  488. // trainingFeatures 訓練データの説明変数(特徴量)x0,x1,x2...
  489. // trainingAnswers 訓練データの目的変数y
  490. // nPerm 特徴量1つあたりのシャッフル予測の回数
  491. // 返り値 特徴量ごとのIncMSE
  492. vector <double> calculateIncMSE(const vector <vector <FeatureType> >& trainingFeatures, const vector < AnswerType >& trainingAnswers, int nPerm) const
  493. {
  494. const int numFeatures = SZ(trainingFeatures[0]);
  495. vector <double> errimp(numFeatures);
  496.  
  497. for (int i = 0; i < SZ(m_trees); i++)
  498. {
  499. const DecisionTree& tree = m_trees[i];
  500. double ooberr = 0.0;
  501.  
  502. // OOBデータを使って、予測。2乗誤差の総和をooberrに入れる。
  503. const vector <int>& oob = tree.getOOB();
  504. for (int x = 0; x < SZ(oob); ++x)
  505. {
  506. const double myAnswer = tree.estimate(trainingFeatures[oob[x]]);
  507. ooberr += (trainingAnswers[oob[x]] - myAnswer)*(trainingAnswers[oob[x]] - myAnswer);
  508. }
  509.  
  510. srand(i);
  511. for (int featureID = 0; featureID < numFeatures; ++featureID) // すべての特徴量featureIDでループ
  512. {
  513. double ooberrperm = 0.0;
  514.  
  515. // OOBデータを、featureID番目の特徴量だけをシャッフル
  516. vector <FeatureType> swapped(SZ(oob));
  517. for (int x = 0; x < SZ(oob); ++x)
  518. {
  519. swapped[x] = trainingFeatures[oob[x]][featureID];
  520. }
  521.  
  522. for (int nploop = 0; nploop < nPerm; ++nploop) // 何回シャッフルしたのを試すか。
  523. {
  524. random_shuffle(swapped.begin(), swapped.end());
  525.  
  526. // すべてのOOBデータでシャッフルしたものを使う
  527. for (int x = 0; x < SZ(oob); ++x)
  528. {
  529. vector <FeatureType> vf(trainingFeatures[oob[x]]);
  530. vf[featureID] = swapped[x]; // featureID番目の要素だけ差し替え
  531.  
  532. const double myAnswer = tree.estimate(vf);
  533. ooberrperm += (trainingAnswers[oob[x]] - myAnswer)*(trainingAnswers[oob[x]] - myAnswer);
  534. }
  535. }
  536.  
  537. const double delta = (ooberrperm / nPerm - ooberr) / SZ(oob);
  538. errimp[featureID] += delta;
  539. }
  540. }
  541.  
  542. for (int featureID = 0; featureID < numFeatures; ++featureID)
  543. {
  544. errimp[featureID] = errimp[featureID] / SZ(m_trees); // 平均をとる
  545. }
  546.  
  547. return errimp;
  548. }
  549. #endif
  550.  
  551. #if INCNODEPURITY
  552. // IncNodePurityを求める
  553. // 返り値 特徴量ごとのIncMSE
  554. vector <double> calculateIncNodePurity() const
  555. {
  556. vector < double > retIncNodePurity;
  557.  
  558. if (SZ(m_trees) > 0)
  559. {
  560. const int numFeatures = SZ(m_trees[0].getIncNodePurity());
  561. retIncNodePurity.resize(numFeatures);
  562. for (int i = 0; i < SZ(m_trees); i++)
  563. {
  564. const DecisionTree& tree = m_trees[i];
  565. const vector <double>& purity = tree.getIncNodePurity();
  566. for (int featureID = 0; featureID < numFeatures; ++featureID)
  567. {
  568. retIncNodePurity[featureID] += purity[featureID];
  569. }
  570. }
  571.  
  572. for (int featureID = 0; featureID < numFeatures; ++featureID)
  573. {
  574. retIncNodePurity[featureID] /= SZ(m_trees);
  575. }
  576. }
  577.  
  578. return retIncNodePurity;
  579. }
  580. #endif // INCNODEPURITY
  581.  
  582. private:
  583. vector < DecisionTree > m_trees; // たくさんの決定木
  584. };
  585.  
  586. int main()
  587. {
  588. int numAll; // 全データ数
  589. int numTrainings; // 訓練データ数
  590. int numTests; // テストデータ数
  591. int numFeatures; // 説明変数の数
  592.  
  593. // y = f(x0,x1,x2,...)
  594. // x0,x1,x2は説明変数です。コード上ではfeatureと命名してます。
  595. // yは目的変数です。コード上ではanswerという命名をしてます。
  596.  
  597.  
  598. cin >> numAll >> numTrainings >> numTests >> numFeatures;
  599. assert(numTrainings+numTests<=numAll);
  600.  
  601. // 全データ
  602. vector < vector <FeatureType> > allFeatures(numAll, vector <FeatureType> (numFeatures));
  603. vector < AnswerType > allAnswers(numAll);
  604.  
  605. for(int i = 0 ; i < numAll; ++i)
  606. {
  607. for (int k = 0; k < numFeatures; ++k)
  608. {
  609. cin >> allFeatures[i][k];
  610. }
  611. cin >> allAnswers[i];
  612. }
  613.  
  614. // シャッフル用
  615. vector < int > shuffleTable;
  616. for (int i = 0; i < numTrainings+numTests; ++i)
  617. {
  618. shuffleTable.emplace_back(i);
  619. }
  620. random_shuffle(shuffleTable.begin(), shuffleTable.end());
  621.  
  622. // 訓練データ
  623. vector < vector <FeatureType> > trainingFeatures(numTrainings, vector <FeatureType>(numFeatures));
  624. vector < AnswerType > trainingAnswers(numTrainings);
  625. for (int i = 0; i < numTrainings; ++i)
  626. {
  627. trainingFeatures[i] = allFeatures[shuffleTable[i]];
  628. trainingAnswers[i] = allAnswers[shuffleTable[i]];
  629. }
  630.  
  631. // テストデータ
  632. vector < vector <FeatureType> > testFeatures(numTests, vector <FeatureType>(numFeatures));
  633. vector < AnswerType > testAnswers(numTests);
  634. for (int i = 0; i < numTests; ++i)
  635. {
  636. testFeatures[i] = allFeatures[shuffleTable[numTrainings+i]];
  637. testAnswers[i] = allAnswers[shuffleTable[numTrainings+i]];
  638. }
  639.  
  640. // ランダムフォレストを使って予測
  641. RandomForest* rf = new RandomForest();
  642.  
  643. // 学習
  644. const int numTrees = 100;
  645. rf->train(trainingFeatures, trainingAnswers, 5, 999, 2, 5, 0.66f, numTrees);
  646.  
  647. // 予測と結果表示
  648. printf("-----\n");
  649.  
  650. #if INCMSE
  651. {
  652. vector <double> vd = rf->calculateIncMSE(trainingFeatures, trainingAnswers, 5);
  653. for (int i = 0; i < SZ(vd); ++i)
  654. {
  655. printf("IncMSE x%d=%8.2f\n",i,vd[i]);
  656. }
  657. }
  658. #endif // INCMSE
  659.  
  660. #if INCNODEPURITY
  661. {
  662. vector <double> vd = rf->calculateIncNodePurity();
  663. for (int i = 0; i < SZ(vd); ++i)
  664. {
  665. printf("IncNodePurity x%d=%8.2f\n",i,vd[i]);
  666. }
  667. }
  668. #endif // INCNODEPURITY
  669.  
  670. double totalError = 0.0;
  671. for (int i = 0; i < numTests; ++i)
  672. {
  673. const double myAnswer = rf->estimateRegression(testFeatures[i]);
  674. const double diff = myAnswer-testAnswers[i];
  675. totalError += abs(diff);
  676. printf("Test%3d myAnswer=%8.2f testAnswer=%8.2f diff=%8.2f\n", i, myAnswer, testAnswers[i], diff);
  677. }
  678. printf("totalError=%8.2f\n",totalError);
  679.  
  680.  
  681. delete rf;
  682.  
  683. return 0;
  684. }
  685.  
  686.  
Success #stdin #stdout 0.01s 3296KB
stdin
44 30 14 6
6.9	1.8	30.2	58.3	27.3	84.9	-14.2
8.4	28.5	38.8	87.5	39.8	172.6	-34.1
5.7	7.8	31.7	83.5	26	154.2	-15.8
7.4	2.3	24.2	14.2	29.4	35.2	-13.9
8.5	-0.7	28.1	46.7	26.6	69.2	-13.9
13.8	7.2	10.4	57.9	26.2	111	-22.6
1.7	32.2	7.5	73.8	50.5	704.1	-40.9
3.6	7.4	30	61.3	26.4	69.9	4
8.2	10.2	12.1	41	11.7	65.4	-32.5
5	10.5	13.6	17.4	14.7	132.1	-8.1
2.1	0.3	18.3	34.4	24.2	179.9	12.3
4.2	8.1	21.3	64.9	21.7	139.9	-35
3.9	2	33.1	82	26.3	108.7	-2
4.1	10.8	38.3	83.3	32.6	123.2	-2.2
4.2	1.9	36.9	61.8	21.6	104.7	-14.2
9.4	-1.5	22.4	22.2	33.5	61.5	-32.7
3.6	-0.3	19.6	8.6	27	68.2	-13.4
7.6	5.5	29.1	62.8	32.2	96.9	-8.7
8.5	4.8	32.8	86.2	16	258	0.5
7.5	2.3	26.5	18.7	23.7	32	-0.6
4.1	17.3	41.5	78.6	23.5	127	-12.5
4.6	68.6	39	14.6	38.2	27.1	45.4
7.2	3	20.2	41.4	27.6	70.7	-38.2
13.4	7.1	20.4	13.9	22.5	38.3	-33.6
10.3	1.4	29.8	43.7	29.4	54	-10
9.4	4.6	36	78.2	29.9	101.5	-14.6
2.5	-3.3	37.6	88.5	27.5	185.9	-7.6
10.3	-0.5	31.8	57.2	27.2	61.2	-17.6
7.5	22.3	28.6	5.7	31.3	38.6	27.2
18.7	6.2	39.7	55.8	28.7	52.6	-2.9
5.1	-2	23.8	29	29.3	62.6	-10.3
3.7	19.6	12.3	77.3	32	207.7	-45.6
10.3	3	31.1	51.7	26.2	42.4	-31.9
7.3	19.2	32.9	68.1	25.2	105.2	-35.7
4.2	7	22.1	41.2	21.4	68.6	-8.8
2.1	5.4	27.1	60	23.5	157.3	6.2
2.5	2.8	20.3	29.8	24.1	58.5	-27.5
8.1	8.5	30	66.4	26	63.1	-37.4
10.3	-1.9	15.9	39.9	38.5	86.4	-13.5
10.5	2.8	36.4	72.3	26	77.5	-21.6
5.8	2	24.2	19.5	28.3	63.5	2.2
6.9	2.9	20.7	6.6	25.8	68.9	-2.4
9.3	4.9	34.9	82.4	18.4	102.8	-12
11.4	2.6	38.7	78.2	18.4	86.6	-12.8
stdout
-----
IncMSE        x0=   -6.46
IncMSE        x1=    5.80
IncMSE        x2=  -12.16
IncMSE        x3=   66.87
IncMSE        x4=  -15.22
IncMSE        x5=    0.69
IncNodePurity x0=  599.36
IncNodePurity x1=  958.41
IncNodePurity x2=  890.15
IncNodePurity x3= 1276.43
IncNodePurity x4=  805.89
IncNodePurity x5= 1258.49
Test  0 myAnswer=  -11.36 testAnswer=  -10.30 diff=   -1.06
Test  1 myAnswer=  -15.27 testAnswer=  -13.50 diff=   -1.77
Test  2 myAnswer=  -14.47 testAnswer=   -2.20 diff=  -12.27
Test  3 myAnswer=  -15.27 testAnswer=  -37.40 diff=   22.13
Test  4 myAnswer=    0.57 testAnswer=  -13.90 diff=   14.47
Test  5 myAnswer=   -5.69 testAnswer=  -13.40 diff=    7.71
Test  6 myAnswer=   -6.41 testAnswer=   -2.90 diff=   -3.51
Test  7 myAnswer=  -14.62 testAnswer=  -12.80 diff=   -1.82
Test  8 myAnswer=   -8.04 testAnswer=   -8.10 diff=    0.06
Test  9 myAnswer=  -11.27 testAnswer=  -38.20 diff=   26.93
Test 10 myAnswer=  -14.22 testAnswer=  -32.50 diff=   18.28
Test 11 myAnswer=   -4.54 testAnswer=  -33.60 diff=   29.06
Test 12 myAnswer=   -9.10 testAnswer=  -12.50 diff=    3.40
Test 13 myAnswer=  -11.43 testAnswer=   -8.80 diff=   -2.63
totalError=  145.10