fork download
  1. use std::fmt;
  2.  
  3. struct LinearFunction {
  4. a: f64,
  5. b: f64,
  6. }
  7.  
  8. impl LinearFunction {
  9. fn new(a: f64, b: f64) -> LinearFunction {
  10. LinearFunction {
  11. a: a,
  12. b: b,
  13. }
  14. }
  15.  
  16. fn feed(&self, x: f64) -> f64 {
  17. self.a * x + self.b
  18. }
  19.  
  20. fn learn(&mut self, input: f64, target_output: f64, learning_rate: f64) -> f64 {
  21. let error = target_output - self.feed(input);
  22.  
  23. self.a -= error * -input * learning_rate;
  24. self.b -= error * -1.0 * learning_rate;
  25.  
  26. return error;
  27. }
  28. }
  29.  
  30. impl fmt::Display for LinearFunction {
  31. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  32. write!(f, "f(x) = {} * x + {}", self.a, self.b)
  33. }
  34. }
  35.  
  36. fn main() {
  37. let mut fun = LinearFunction::new(0.0, 0.0);
  38.  
  39. println!("Before");
  40.  
  41. let data_list = [
  42. (0.0, 1.0),
  43. (1.0, 3.0),
  44. (2.0, 5.0),
  45. (3.0, 7.0),
  46. ];
  47.  
  48. let learning_rate = 0.01 / data_list.len() as f64;
  49.  
  50. for epoch in 0..10000 {
  51. let mut mse = 0_f64;
  52.  
  53. for data in data_list.iter() {
  54. let error = fun.learn(data.0, data.1, learning_rate);
  55. mse += error * error;
  56. }
  57.  
  58. if epoch % 1000 == 0 {
  59. println!("{}", fun);
  60. println!("Error : {}", mse / data_list.len() as f64);
  61. }
  62. }
  63.  
  64. println!("After");
  65. println!("{}", fun);
  66. }
Success #stdin #stdout 0.01s 4388KB
stdin
Standard input is empty
stdout
Before
f(x) = 0.0839723100390625 * x + 0.039632452304687496
Error : 20.519144377863604
f(x) = 2.0011632736774065 * x + 0.9974990391031798
Error : 0.0000022823391159126393
f(x) = 2.000058928111989 * x + 0.9998733084856376
Error : 0.000000005856813909427799
f(x) = 2.0000029851293366 * x + 0.9999935821708257
Error : 0.000000000015029435769872653
f(x) = 2.0000001512181007 * x + 0.9999996748911606
Error : 0.00000000000003856771653562187
f(x) = 2.000000007660276 * x + 0.999999983530919
Error : 0.0000000000000000989703538043435
f(x) = 2.0000000003880474 * x + 0.9999999991657231
Error : 0.00000000000000000025397248558118256
f(x) = 2.0000000000196576 * x + 0.9999999999577371
Error : 0.0000000000000000000006517520649855388
f(x) = 2.000000000000997 * x + 0.9999999999978586
Error : 0.0000000000000000000000016736991178990983
f(x) = 2.0000000000000444 * x + 0.9999999999998948
Error : 0.000000000000000000000000003977942048141749
After
f(x) = 2.000000000000026 * x + 0.9999999999999513