use std::fmt;
/// A one-dimensional linear model `f(x) = a * x + b`, trained one sample at a
/// time by gradient descent on squared error (see [`LinearFunction::learn`]).
#[derive(Debug, Clone, Copy, PartialEq)]
struct LinearFunction {
    /// Slope: the weight applied to the input.
    a: f64,
    /// Intercept: the constant offset.
    b: f64,
}

impl LinearFunction {
    /// Creates a model with slope `a` and intercept `b`.
    fn new(a: f64, b: f64) -> LinearFunction {
        LinearFunction { a, b }
    }

    /// Evaluates the model at `x`, returning `a * x + b`.
    fn feed(&self, x: f64) -> f64 {
        self.a * x + self.b
    }

    /// Performs one gradient-descent step on a single `(input, target_output)`
    /// sample and returns the error `target_output - f(input)` measured
    /// *before* the parameters are updated.
    ///
    /// The step is the gradient step for the loss `½ · error²`:
    /// `∂L/∂a = -error · input` and `∂L/∂b = -error`, so moving against the
    /// gradient gives `a += lr · error · input` and `b += lr · error`.
    fn learn(&mut self, input: f64, target_output: f64, learning_rate: f64) -> f64 {
        let error = target_output - self.feed(input);

        // Both parameters use the same pre-update error. The multiplication
        // order (error * input * lr) is kept so results match the original
        // `-= error * -input * lr` bit-for-bit.
        self.a += error * input * learning_rate;
        self.b += error * learning_rate;

        error
    }
}

impl fmt::Display for LinearFunction {
    /// Formats the model as `f(x) = {a} * x + {b}`.
    ///
    /// When the formatter carries a precision (e.g. `{:.3}`), it is applied to
    /// both coefficients. Otherwise they use the default `f64` formatting, so
    /// plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match f.precision() {
            Some(p) => write!(f, "f(x) = {:.*} * x + {:.*}", p, self.a, p, self.b),
            None => write!(f, "f(x) = {} * x + {}", self.a, self.b),
        }
    }
}
/// Fits `f(x) = a * x + b` to four noise-free samples of `y = 2x + 1`.
///
/// Uses per-sample gradient descent. Prints the model and its mean squared
/// error every `REPORT_EVERY` epochs, then prints the final model.
fn main() {
    const EPOCHS: usize = 10_000;
    const REPORT_EVERY: usize = 1_000;

    let mut fun = LinearFunction::new(0.0, 0.0);

    println!("Before");
    // Show the untrained model under the "Before" heading, mirroring "After".
    println!("{}", fun);

    // Samples of y = 2x + 1.
    let data_list = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)];
    let sample_count = data_list.len() as f64;

    // The base step is divided by the number of samples, since every sample
    // triggers its own update within an epoch.
    let learning_rate = 0.01 / sample_count;

    for epoch in 0..EPOCHS {
        // Sum of squared errors over this epoch. Each error is measured just
        // before its own update, so this is a running figure rather than the
        // error of the model printed below.
        let mut sse = 0.0_f64;
        for &(x, y) in &data_list {
            let error = fun.learn(x, y, learning_rate);
            sse += error * error;
        }

        if epoch % REPORT_EVERY == 0 {
            println!("{}", fun);
            println!("Error : {}", sse / sample_count);
        }
    }

    println!("After");
    println!("{}", fun);
}
dXNlIHN0ZDo6Zm10OwoKc3RydWN0IExpbmVhckZ1bmN0aW9uIHsKICAgIGE6IGY2NCwKICAgIGI6IGY2NCwKfQoKaW1wbCBMaW5lYXJGdW5jdGlvbiB7CiAgICBmbiBuZXcoYTogZjY0LCBiOiBmNjQpIC0+IExpbmVhckZ1bmN0aW9uIHsKICAgICAgICBMaW5lYXJGdW5jdGlvbiB7CiAgICAgICAgICAgIGE6IGEsCiAgICAgICAgICAgIGI6IGIsCiAgICAgICAgfQogICAgfQoKICAgIGZuIGZlZWQoJnNlbGYsIHg6IGY2NCkgLT4gZjY0IHsKICAgICAgICBzZWxmLmEgKiB4ICsgc2VsZi5iCiAgICB9CgogICAgZm4gbGVhcm4oJm11dCBzZWxmLCBpbnB1dDogZjY0LCB0YXJnZXRfb3V0cHV0OiBmNjQsIGxlYXJuaW5nX3JhdGU6IGY2NCkgLT4gZjY0IHsKICAgICAgICBsZXQgZXJyb3IgPSB0YXJnZXRfb3V0cHV0IC0gc2VsZi5mZWVkKGlucHV0KTsKICAgICAgICAKICAgICAgICBzZWxmLmEgLT0gZXJyb3IgKiAtaW5wdXQgKiBsZWFybmluZ19yYXRlOwogICAgICAgIHNlbGYuYiAtPSBlcnJvciAqIC0xLjAgKiBsZWFybmluZ19yYXRlOwogICAgICAgIAogICAgICAgIHJldHVybiBlcnJvcjsKICAgIH0KfQoKaW1wbCBmbXQ6OkRpc3BsYXkgZm9yIExpbmVhckZ1bmN0aW9uIHsKICAgIGZuIGZtdCgmc2VsZiwgZjogJm11dCBmbXQ6OkZvcm1hdHRlcikgLT4gZm10OjpSZXN1bHQgewogICAgICAgIHdyaXRlIShmLCAiZih4KSA9IHt9ICogeCArIHt9Iiwgc2VsZi5hLCBzZWxmLmIpCiAgICB9Cn0KCmZuIG1haW4oKSB7CiAgICBsZXQgbXV0IGZ1biA9IExpbmVhckZ1bmN0aW9uOjpuZXcoMC4wLCAwLjApOwogICAgCiAgICBwcmludGxuISgiQmVmb3JlIik7CiAgICAKICAgIGxldCBkYXRhX2xpc3QgPSBbCiAgICAgICAgKDAuMCwgMS4wKSwKICAgICAgICAoMS4wLCAzLjApLAogICAgICAgICgyLjAsIDUuMCksCiAgICAgICAgKDMuMCwgNy4wKSwKICAgIF07CiAgICAKICAgIGxldCBsZWFybmluZ19yYXRlID0gMC4wMSAvIGRhdGFfbGlzdC5sZW4oKSBhcyBmNjQ7CiAgICAKICAgIGZvciBlcG9jaCBpbiAwLi4xMDAwMCB7CiAgICAgICAgbGV0IG11dCBtc2UgPSAwX2Y2NDsKICAgIAogICAgICAgIGZvciBkYXRhIGluIGRhdGFfbGlzdC5pdGVyKCkgewogICAgICAgICAgICBsZXQgZXJyb3IgPSBmdW4ubGVhcm4oZGF0YS4wLCBkYXRhLjEsIGxlYXJuaW5nX3JhdGUpOwogICAgICAgICAgICBtc2UgKz0gZXJyb3IgKiBlcnJvcjsKICAgICAgICB9CiAgICAgICAgCiAgICAgICAgaWYgZXBvY2ggJSAxMDAwID09IDAgewogICAgICAgICAgICBwcmludGxuISgie30iLCBmdW4pOwogICAgICAgICAgICBwcmludGxuISgiRXJyb3IgOiB7fSIsIG1zZSAvIGRhdGFfbGlzdC5sZW4oKSBhcyBmNjQpOwogICAgICAgIH0KICAgIH0KICAgIAogICAgcHJpbnRsbiEoIkFmdGVyIik7CiAgICBwcmludGxuISgie30iLCBmdW4pOwp9