fork download
  1. open System
  2.  
  /// Pseudo-random source shared by every sampling helper in this script.
  /// It is created without a seed, so successive runs draw different values.
  let random = Random()
  4.  
  /// Swaps the first two arguments of a curried function: flip f x y = f y x.
  /// Used below to build "divide by total" partial applications: flip (/) total.
  let flip f x y = f y x
  6.  
  /// Applies f to the second element of a pair and keeps the left element
  /// unchanged: keepLeft f (x, y) = (x, f y).
  let keepLeft f pair =
      let left, right = pair
      left, f right
  8.  
  /// Applies f to the first element of a pair and keeps the right element
  /// unchanged: keepRight f (x, y) = (f x, y).
  let keepRight f pair =
      let left, right = pair
      f left, right
  10.  
  /// First component of a triple.
  let fst3 triple =
      match triple with
      | first, _, _ -> first
  12.  
  /// First two components of a triple, returned as a pair.
  let fst_snd3 triple =
      match triple with
      | first, second, _ -> first, second
  14.  
  /// Linearly maps `value` from the interval [rangemin, rangemax] onto
  /// [rmin, rmax], so rangemin -> rmin and rangemax -> rmax. Values outside
  /// the source interval are extrapolated, not clamped.
  /// The ratio (value - rangemin) / (rangemax - rangemin) does not change when
  /// both ends are shifted, so no "translate to 0" step is needed for negative
  /// ranges. A zero-width source interval (rangemin = rangemax) yields
  /// infinity or NaN.
  let scaleTo (rmin: float) (rmax: float) (rangemin: float) (rangemax: float) (value: float) =
      let fraction = (value - rangemin) / (rangemax - rangemin)
      fraction * (rmax - rmin) + rmin
  21.  
  /// Helpers used alongside the core Array functions; the rest of this script
  /// calls them as Array.normalize and Array.normalizeWeights.
  module Array =
      /// Divides every element by the array's sum so the result sums to 1.
      /// For float input a zero total produces NaN/infinity entries.
      let inline normalize (a: _[]) =
          let total = Array.sum a
          a |> Array.map (fun x -> x / total)

      /// Divides the weight (second element) of each (key, weight) pair by the
      /// sum of all weights; keys are left unchanged.
      let inline normalizeWeights (a: ('a * 'b) []) =
          let total = Array.sumBy snd a
          a |> Array.map (fun (key, weight) -> key, weight / total)
  30.  
  31.  
  /// Running sums of `p` in REVERSED order: the first element is the grand
  /// total and the last element is p.[0], e.g.
  /// cdf [|1.; 2.; 3.|] = [|6.; 3.; 1.|].
  /// getDiscreteSampleFromCDF expects exactly this reversed layout.
  let cdf p =
      p
      |> Array.scan (+) 0.
      |> Array.tail   // drop the 0. seed that scan puts in front
      |> Array.rev
  38.  
  /// Draws an index into the weights `pcdf` was built from, where `pcdf` has
  /// the reversed running-sum layout produced by `cdf` (pcdf.[0] is the total).
  /// For non-negative weights index i is drawn with probability proportional to
  /// its weight. Raises IndexOutOfRangeException on an empty array.
  let getDiscreteSampleFromCDF (pcdf: float[]) =
      let threshold = random.NextDouble() * pcdf.[0]
      let last = pcdf.Length - 1
      // Walk from the end (the smallest running sum) towards the front and stop
      // at the first running sum that is >= threshold.
      let rec walk i =
          if threshold > pcdf.[i] then walk (i - 1) else i
      // Position `walk last` in the reversed array is index (last - it) in the
      // original weights; it never exceeds `last`, so the result is >= 0.
      last - walk last
  45.  
  46.  
  /// Samples an index of `p` with probability proportional to p.[i]; `p` need
  /// not sum to 1 (presumably it is non-negative).
  let discreteSample p = getDiscreteSampleFromCDF (cdf p)
  48.  
  49.  
  /// Rounds x to n decimal places with System.Math.Round, which sends midpoints
  /// to the even neighbour (banker's rounding).
  /// NOTE: this shadows F#'s built-in single-argument `round` for the rest of the file.
  let round (n: int) (x: float) = Math.Round(x, n, MidpointRounding.ToEven)
  51.  
  52.  
  53. //////////////////////////////
  54.  
  /// A driver's action at the intersection: stop or go.
  type Go =
      | ``Do 🛑``
      | ``Do 🚙``

  /// Not referenced by any of the code shown in this script.
  type Decision =
      | ``Do 😏`` of float []
      | ``Do 🚦``

  /// Every possible action, in declaration order.
  let moves = [| ``Do 🛑``; ``Do 🚙`` |]
  60.  
  /// One multiplicative-weights step.
  /// Each result r (a payoff: larger is better) is clamped to
  /// [minAmount, maxAmount] and rescaled linearly onto [-1, 1]. The matching
  /// weight oldweights.[i] is then multiplied by (1 + rate * scaled), and the new
  /// weights are normalized to sum to 1.
  /// `results` must not be longer than `oldweights`, otherwise indexing
  /// oldweights throws. With rate in [0, 1) every factor stays positive, so
  /// weights never become zero or negative through this update.
  let multiplicativeWeightsUpdate rate minAmount maxAmount (oldweights: float []) (results: seq<float>) =
      let rescale = scaleTo (-1.) 1. minAmount maxAmount
      results
      |> Seq.mapi (fun i r ->
          // Clamp both ends: a payoff below minAmount would otherwise rescale
          // below -1 and could drive the factor to zero or below.
          let clamped = r |> min maxAmount |> max minAmount
          oldweights.[i] * (1. + rate * rescale clamped))
      |> Seq.toArray
      |> Array.normalize
  66.  
  /// Multiplicative-weights update used by the learners. Payoffs are bounded
  /// to [-100, 1], presumably chosen to match `drive` (scores there are a
  /// weight <= 1 times a payoff in {-100, 0, 1}).
  let learningExpert2 prevWeights rate res =
      let worstPayoff, bestPayoff = -100., 1.
      multiplicativeWeightsUpdate rate worstPayoff bestPayoff prevWeights res
  68.  
  /// Draws a complementary pair of lights: with probability p the first light
  /// is stop and the second go, otherwise the reverse. The pair never shows
  /// the same signal twice.
  let sampleLight p =
      let stopGo = ``Do 🛑``, ``Do 🚙``
      let goStop = ``Do 🚙``, ``Do 🛑``
      if random.NextDouble() < p then stopGo else goStop
  70.  
  /// A single light that is stop or go with probability 1/2 each.
  let randomLight () =
      match random.NextDouble() < 0.5 with
      | true -> ``Do 🛑``
      | false -> ``Do 🚙``
  /// A possibly faulty light pair. With probability p each light is drawn
  /// independently by randomLight, so both may show the same signal.
  /// Otherwise a proper complementary pair comes from `sampleLight p2`.
  let faultyLight p2 p =
      let broken = random.NextDouble() < p
      if broken then
          let first = randomLight ()
          let second = randomLight ()
          first, second
      else
          sampleLight p2
  73.  
  74.  
  75. //============
  76.  
  /// Payoffs (first driver, second driver) for a pair of actions.
  /// Going while the other stops earns 1 and stopping earns 0.
  /// Both going is heavily penalised (-100 each).
  let drive (first, second) =
      match first, second with
      | ``Do 🚙``, ``Do 🚙`` -> -100., -100.
      | ``Do 🚙``, ``Do 🛑`` -> 1., 0.
      | ``Do 🛑``, ``Do 🚙`` -> 0., 1.
      | ``Do 🛑``, ``Do 🛑`` -> 0., 0.
  83.  
  /// Debug variant of getExpert2. It prints the signals, every rule that
  /// applies, the normalized weights and their totals per choice. It then
  /// samples a choice from the applicable rules in proportion to their weights.
  /// `rules` holds (weight, name, rule) triples; a rule returns Some choice
  /// when it applies to (indicated, signal).
  let inline getExpertVerbose indicated signal (rules: _[]) =
      printfn "Signal = %A" signal
      let otherText =
          match indicated with
          | Some other -> string other
          | None -> "N/A"
      printfn "Other signal: %s" otherText

      let matched =
          rules
          |> Array.choose (fun (weight, name, rule) ->
              match rule indicated signal with
              | Some choice ->
                  printfn "Rule %s with weight %A matched" name weight
                  Some (choice, weight)
              | None -> None)
          |> Array.normalizeWeights
      printfn "Normalized %A" matched

      let totalsByChoice =
          matched
          |> Array.groupBy fst
          |> Array.map (fun (choice, group) -> choice, Array.sumBy snd group)
      printfn "%A" totalsByChoice

      let picked = discreteSample (Array.map snd matched)
      fst matched.[picked]
  101.  
  /// Picks a choice by sampling among the rules that apply to
  /// (indicated, signal), each with probability proportional to its weight.
  /// `rules` holds (weight, name, rule) triples; the name is ignored here.
  /// If no rule applies, the sampler is handed an empty array and raises
  /// IndexOutOfRangeException.
  let inline getExpert2 indicated signal (rules: _[]) =
      let candidates =
          rules
          |> Array.choose (fun (weight, _, rule) ->
              rule indicated signal |> Option.map (fun choice -> choice, weight))
          |> Array.normalizeWeights
      let picked = candidates |> Array.map snd |> discreteSample
      fst candidates.[picked]
  109.  
  110.  
  /// Updates the rule weights IN PLACE after seeing the other driver's move
  /// `opmove`.
  /// A rule that applies to (Some opmove, signal) scores its current weight
  /// times its own payoff from `drive` against opmove; other rules score 0.
  /// The scores go through learningExpert2, and the new weights are written
  /// back into `rules`; names and rule functions are unchanged.
  /// NOTE(review): a score of 0 maps to about 0.98 on the [-1, 1] scale used by
  /// learningExpert2. Non-applying rules are therefore rewarded almost as much
  /// as a successful "go" — confirm this is intended.
  let learnExperts3 signal rate opmove (rules: _[]) =
      // Experiment: pass None instead of (Some opmove) so rules ignore the
      // other driver's move, and compare the learned weights.
      let scores =
          rules
          |> Array.map (fun (weight, _, rule) ->
              match rule (Some opmove) signal with
              | Some choice -> weight * fst (drive (choice, opmove))
              | None -> 0.)
      let updated = learningExpert2 (Array.map fst3 rules) rate scores
      rules |> Array.iteri (fun i (_, name, rule) -> rules.[i] <- (updated.[i], name, rule))
  125.  
  126.  
  /// Verbose twin of learnExperts3. For every rule that applies to
  /// (Some opmove, signal) it prints (weight, name, opmove, choice, payoff),
  /// then updates the weights in `rules` in place through learningExpert2.
  let learnExpertsVerbose signal rate opmove (rules: _[]) =
      let scores =
          rules
          |> Array.map (fun (weight, name, rule) ->
              match rule (Some opmove) signal with
              | Some choice ->
                  let payoff = fst (drive (choice, opmove))
                  printfn "%A" (weight, name, opmove, choice, payoff)
                  weight * payoff
              | None -> 0.)
      let updated = learningExpert2 (Array.map fst3 rules) rate scores
      for i in 0 .. rules.Length - 1 do
          let _, name, rule = rules.[i]
          rules.[i] <- (updated.[i], name, rule)
  139.  
  140.  
  /// The initial rule pool: one rule per (own signal, other signal, action)
  /// combination. With two moves that is 2 * 2 * 2 = 8 rules, each starting
  /// at weight 1/8.
  /// A rule fires when the driver's own signal equals ownSignal and, if
  /// `other` is Some, its content equals otherSignal; it then proposes `action`.
  /// With `other` = None only the own signal is checked.
  let rules =
      let ruleCount = float (moves.Length * moves.Length * moves.Length)
      [|
          for ownSignal in moves do
              for otherSignal in moves do
                  for action in moves do
                      let description =
                          sprintf "if signal=%A && other=%A then %A" ownSignal otherSignal action
                      // `other` is whatever the caller reveals about the other
                      // driver (its light or its chosen move), or None.
                      let fire other signal =
                          let otherMatches =
                              match other with
                              | None -> true
                              | Some indicated -> indicated = otherSignal
                          if signal = ownSignal && otherMatches then Some action else None
                      yield (1. / ruleCount, description, fire)
      |]
  154.  
  155.  
  /// Plays n + 1 evaluation rounds (the bound 0 .. n is inclusive), each with
  /// a fresh 50/50 light pair. Returns every (driver 1 move, driver 2 move).
  /// Driver 1 chooses with `heroweights` and, if look1, sees driver 2's light.
  /// Driver 2 chooses with `otherweight` and, if look2, sees driver 1's move.
  /// No weights are updated.
  let gatherStats n look1 look2 otherweight heroweights =
      let playRound () =
          let light1, light2 = sampleLight 0.5
          let heroMove = getExpert2 (if look1 then Some light2 else None) light1 heroweights
          let otherMove = getExpert2 (if look2 then Some heroMove else None) light2 otherweight
          heroMove, otherMove
      let rounds = [ for _ in 0 .. n -> playRound () ]
      ResizeArray(rounds)
  164.  
  /// Formats a fraction as a percentage rounded to two decimals and printed
  /// with %A, e.g. 0.482 -> "48.2%".
  let round100 x =
      let percent = float x * 100.
      sprintf "%A%%" (round 2 percent)
  166.  
  /// Trains two drivers against each other for 100,000 rounds of
  /// multiplicative-weights learning (rate 0.5), each starting from a copy of
  /// `rules`. It then prints the empirical joint move distribution over 1,000
  /// further rounds, followed by each driver's rules whose weight does not
  /// round to 0.
  let learner2 () =
      // Copies are needed because learnExperts3 overwrites the arrays it is
      // given, and the shared `rules` pool must stay untouched. The bindings
      // themselves are never reassigned, so they need not be mutable.
      let heroweights = Array.copy rules
      let otherweight = Array.copy rules
      let rate = 0.5

      for _ in 0 .. 99999 do
          let light1, light2 = sampleLight 0.5
          // Experiment: pass None instead of (Some light2) to hide driver 2's light.
          let p1mov = getExpert2 (Some light2) light1 heroweights
          let p2mov = getExpert2 (Some p1mov) light2 otherweight
          learnExperts3 light1 rate p2mov heroweights
          learnExperts3 light2 rate p1mov otherweight

      // gatherStats' bound is inclusive: 999 gives 1,000 evaluation rounds.
      let gathered = gatherStats 999 true true otherweight heroweights

      gathered.ToArray()
      |> Array.groupBy id
      |> Array.map (keepLeft (Array.length >> float))
      |> Array.normalizeWeights
      |> Array.map (keepLeft round100)
      |> printfn "%A\n"

      // Prints "<label>: [| (weight%, rule name) ... |]", skipping rules
      // whose weight rounds to 0 at two decimals.
      let printRules label weights =
          weights
          |> Array.map fst_snd3
          |> Array.filter (fst >> round 2 >> (<>) 0.)
          |> Array.map (keepRight round100)
          |> printfn "%s: %A" label

      printRules "Rules 1" heroweights
      printRules "Rules 2" otherweight
  192.  
  // Top-level statement: runs the experiment once when the script executes.
  do learner2 ()
  194.  
Success #stdin #stdout 0.75s 135232KB
stdin
Standard input is empty
stdout
[|((Do 🚙, Do 🛑), "48.2%"); ((Do 🛑, Do 🚙), "51.0%"); ((Do 🚙, Do 🚙), "0.8%")|]

Rules 1: [|("100.0%", "if signal=Do 🚙 && other=Do 🛑 then Do 🚙")|]
Rules 2: [|("100.0%", "if signal=Do 🚙 && other=Do 🛑 then Do 🚙")|]