fork(4) download
  1. tree = [{'left':-1,'right':-1,'poison':0,'sets':[]}]
  2.  
  3. mingain = 0.4
  4.  
  5. def entropy(sets):
  6. if len(sets)==0:return 0.0
  7. from math import log
  8. log2 = lambda x:log(x)/log(2)
  9. ent = 0.0
  10. result = [0 for i in xrange(2)]
  11. for s in sets:
  12. result[s[19]] += 1
  13. for i in xrange(2):
  14. p = float(result[i])/len(sets)
  15. if p != 0:ent -= p*log2(p)
  16. return ent
  17. # Variance of the last column (currently commented out)
  18. #def variance(sets):
  19. # if len(sets)==0:return 0
  20. # data = [float(s[len(s)-1])for s in sets]
  21. # mean = sum(data)/len(data)
  22. # var = sum([(d-mean)**2 for d in data])/len(data)
  23. # return var
  24.  
  25. def divide(sets,index,criteria):
  26. l = []
  27. r = []
  28. for s in sets:
  29. if s[index] in criteria : l.append(s)
  30. else : r.append(s)
  31. return (l,r)
  32.  
  33. def build(n,sets):
  34.  
  35. tree[n]['sets']=sets
  36. if len(sets)==0:return
  37.  
  38. best_gain = 0.0
  39. best_criteria = {'left':-1,'right':-1,'poison':0,'sets':sets}
  40. cur = entropy(sets)
  41.  
  42. for i in xrange(19):
  43. criteria = {'left':-1,'right':-1,'index':i,'set':[]}
  44. a = [[0 for j in xrange(2)]for k in xrange(9)]
  45. for s in sets:
  46. a[s[i]][s[19]] += 1
  47. for j in xrange(9):
  48. if a[j][0] > a[j][1]:
  49. criteria['set'].append(j)
  50. l,r = divide(sets,criteria['index'],criteria['set'])
  51. p = float(len(l)/len(sets))
  52. gain = cur - (entropy(l)*p+entropy(r)*(1.0-p))
  53. if(gain > best_gain ):
  54. best_gain = gain
  55. best_criteria = criteria
  56.  
  57. if best_gain<=0.0:return
  58.  
  59. tree[n]=best_criteria
  60. tree[n]['left']=len(tree)
  61. tree.append({'left':-1,'right':-1,'poison':0,'sets':l})
  62. tree[n]['right']=len(tree)
  63. tree.append({'left':-1,'right':-1,'poison':1,'sets':r})
  64. l,r = divide(sets,tree[n]['index'],tree[n]['set'])
  65. build(tree[n]['left'],l)
  66. build(tree[n]['right'],r)
  67. return
  68.  
  69. def classify(n,sets):
  70. if tree[n]['left']==-1:return tree[n]['poison']
  71. if sets[tree[n]['index']] in tree[n]['set']:return classify(tree[n]['left'],sets)
  72. else : return classify(tree[n]['right'],sets)
  73.  
  74. def printtree(n,indent=''):
  75. if 'poison' in tree[n]:
  76. print tree[n]['poison']
  77. else:
  78. print str(tree[n]['index'])+':'+str(tree[n]['set'])+'? '
  79. print indent+'T->',
  80. printtree(tree[n]['left'],indent+' ')
  81. print indent+'F->',
  82. printtree(tree[n]['right'],indent+' ')
  83.  
  84. if __name__ == '__main__':
  85.  
  86. f = open('learn.txt','r')
  87. learn = []
  88. for line in f:
  89. learn.append(map(int,line.split(',')))
  90. f.close()
  91.  
  92. build(0,learn)
  93.  
  94. t = open('test.txt','r')
  95. test = []
  96. for line in t:
  97. test.append(map(int,line.split(',')))
  98. t.close()
  99.  
  100. printtree(0)
  101.  
  102. for i in test:
  103. print classify(0,i)
  104.  
Runtime error #stdin #stdout #stderr 0.1s 8904KB
stdin
Standard input is empty
stdout
Standard output is empty
stderr
Traceback (most recent call last):
  File "prog.py", line 86, in <module>
IOError: [Errno 2] No such file or directory: 'learn.txt'