fork download
  1. import java.util.Random;
  2.  
  /**
   * One labelled training sample: an input vector paired with its target output.
   */
  class Example {
    /** Input vector; the array handed to the constructor is stored by reference. */
    double[] input;

    /** Target value associated with {@link #input}. */
    double output;

    /**
     * Creates a sample from an input vector and its target value.
     *
     * @param in  input vector (not copied)
     * @param out target output for that vector
     */
    Example(double[] in, double out) {
      input = in;
      output = out;
    }
  }
  12.  
  /**
   * A hidden-layer neuron with a logistic (sigmoid) activation, trained by the
   * delta rule.
   *
   * <p>Usage per training sample: {@link #computeOutput(double[])}, then
   * {@link #computeError(double, double)}, then {@link #fixError()}.
   */
  class HiddenSigmoidUnit {
    /** Last activation produced by {@link #computeOutput(double[])}. */
    double output;
    /** Last error term produced by {@link #computeError(double, double)}. */
    double error;
    /** Input vector seen by the most recent {@link #computeOutput(double[])} call. */
    double[] input;
    /** One weight per input. */
    double[] w;
    /** Learning rate. */
    double n = 0.05;
    /** Bias term added to the weighted sum; adjusted by {@link #fixError()}. */
    double bias = 1;

    static Random rand = new Random();

    /** Creates a unit with two inputs. */
    HiddenSigmoidUnit() {
      this(2);
    }

    /**
     * Creates a unit with the given number of inputs.
     *
     * @param nInputs number of input weights to allocate
     */
    HiddenSigmoidUnit(int nInputs) {
      w = new double[nInputs];
      // Draw every weight independently from [-1, 1) so the signs are not all
      // forced to be the same.
      for (int i = 0; i < w.length; ++i) {
        w[i] = 2 * rand.nextDouble() - 1;
      }
    }

    /**
     * Computes the sigmoid of {@code bias + sum(w[i] * input[i])} and remembers
     * the input for the next {@link #fixError()} call.
     *
     * @param input input vector holding at least {@code w.length} values
     * @return the activation, also stored in {@link #output}
     * @throws IllegalArgumentException if {@code input} has fewer values than weights
     */
    public double computeOutput(double[] input) {
      if (input.length < w.length) {
        throw new IllegalArgumentException(
            "expected at least " + w.length + " inputs but got " + input.length);
      }
      this.input = input;
      double net = bias;
      for (int i = 0; i < w.length; ++i) {
        net += w[i] * input[i];
      }
      output = 1 / (1 + Math.exp(-net));
      return output;
    }

    /**
     * Computes this unit's error term from the downstream unit's error.
     *
     * @param w               weight connecting this unit to the downstream unit
     * @param outputUnitError error term of the downstream unit
     * @return {@code output * (1 - output) * (outputUnitError * w)}, also stored in {@link #error}
     */
    public double computeError(double w, double outputUnitError) {
      error = output * (1 - output) * (outputUnitError * w);
      return error;
    }

    /**
     * Applies one delta-rule step using the stored {@link #error} and {@link #input}.
     */
    public void fixError() {
      // Iterate over the weights (not the input) so a longer input can never
      // index past the end of w.
      for (int i = 0; i < w.length; ++i) {
        w[i] += n * error * input[i];
      }
      // The bias behaves like a weight on a constant input of 1 and must be
      // learned as well; otherwise the unit's threshold stays fixed.
      bias += n * error;
    }
  }
  43.  
  /**
   * The output neuron: a logistic (sigmoid) unit fed by the hidden layer and
   * trained by the delta rule against a target value.
   *
   * <p>Usage per training sample: {@link #computeOutput(double[])}, then
   * {@link #computeError(double)}, then {@link #fixError()}.
   */
  class OutputUnit {
    /** Target value passed to the most recent {@link #computeError(double)} call. */
    double t;
    /** Last activation produced by {@link #computeOutput(double[])}. */
    double output;
    /** Last error term produced by {@link #computeError(double)}. */
    double error;
    /** Input vector seen by the most recent {@link #computeOutput(double[])} call. */
    double[] input;
    /** One weight per hidden unit. */
    double[] w;
    /** Bias term added to the weighted sum; adjusted by {@link #fixError()}. */
    double bias = 1;
    /** Learning rate. */
    double n = 0.05;

    static Random rand = new Random();

    /**
     * Creates an output unit with one weight per hidden unit.
     *
     * @param hUnits number of hidden units feeding this unit
     */
    OutputUnit(int hUnits) {
      w = new double[hUnits];
      // Weights in [-1, 1), each drawn independently so the signs can differ.
      for (int i = 0; i < w.length; ++i) {
        w[i] = 2 * rand.nextDouble() - 1;
      }
    }

    /**
     * Computes the sigmoid of {@code bias + sum(w[i] * input[i])} over all
     * weights, stores it in {@link #output}, and remembers the input for the
     * next {@link #fixError()} call.
     *
     * @param input hidden-layer outputs, holding at least {@code w.length} values
     * @throws IllegalArgumentException if {@code input} has fewer values than weights
     */
    public void computeOutput(double[] input) {
      if (input.length < w.length) {
        throw new IllegalArgumentException(
            "expected at least " + w.length + " inputs but got " + input.length);
      }
      this.input = input;
      // Sum over every weight rather than a fixed pair, so the unit works for
      // any hidden-layer size it was constructed with.
      double net = bias;
      for (int i = 0; i < w.length; ++i) {
        net += input[i] * w[i];
      }
      output = 1 / (1 + Math.exp(-net));
    }

    /**
     * Computes the error term {@code output * (1 - output) * (t - output)} and
     * stores it in {@link #error}.
     *
     * @param t target value for the current sample
     */
    public void computeError(double t) {
      this.t = t;
      error = output * (1 - output) * (t - output);
    }

    /**
     * Applies one delta-rule step using the stored {@link #error} and {@link #input}.
     */
    public void fixError() {
      for (int i = 0; i < w.length; ++i) {
        w[i] += n * error * input[i];
      }
      // The bias behaves like a weight on a constant input of 1 and must be
      // learned as well; otherwise the unit's threshold stays fixed.
      bias += n * error;
    }
  }
  75.  
  76. class Net{
  77. static Random rand = new Random();
  78. static OutputUnit outputUnit;
  79. static HiddenSigmoidUnit[] hlayer;
  80.  
  81. static Example[] examples = {
  82. new Example(new double[]{0, 0}, 0),
  83. new Example(new double[]{0, 1}, 1),
  84. new Example(new double[]{1, 0}, 1),
  85. new Example(new double[]{1, 1}, 0)
  86. };
  87.  
  88. public static void main(String[] args){
  89. //units in hidden layer
  90. int nhidden = 2;
  91.  
  92. //Holds hidden layers' units
  93. hlayer = new HiddenSigmoidUnit[nhidden];
  94.  
  95. //Initialize
  96. for(int i=0;i<hlayer.length;++i)hlayer[i] = new HiddenSigmoidUnit();
  97. outputUnit = new OutputUnit(hlayer.length);
  98.  
  99. //test multiple times for convergence to optimal values
  100. for(int check=0;check<100;++check){
  101.  
  102. for(int i=0;i<hlayer.length;++i)hlayer[i] = new HiddenSigmoidUnit();
  103. outputUnit = new OutputUnit(hlayer.length);
  104.  
  105. //iteration count
  106. for(int iteration = 0;;++iteration){
  107.  
  108. //if count == 4, all the examples were classified correctly, break
  109. int count=0;
  110.  
  111. //Shuffle examples
  112. /*
  113.   for(int i=0;i<examples.length;++i){
  114.   for(int j=i+1;j<examples.length;++j){
  115.   int choose = rand.nextInt(j);
  116.   Example t = examples[choose];
  117.   examples[choose] = examples[j];
  118.   examples[j] = t;
  119.   }
  120.   }
  121.   */
  122.  
  123. //for each example
  124. for(int i=0;i<examples.length;++i){
  125.  
  126. Example example = examples[i];
  127.  
  128. //collect the outputs from hidden layer to pass on to output unit
  129. double[] outputsHLayer = new double[hlayer.length];
  130. for(int j=0;j<hlayer.length;++j) outputsHLayer[j] = hlayer[j].computeOutput(example.input);
  131.  
  132. //pass to output unit
  133. outputUnit.computeOutput(outputsHLayer);
  134.  
  135. //Compute Errors for output and hidden layer units
  136. //Passing the true output, compute outputUnit error
  137. outputUnit.computeError(example.output);
  138.  
  139. //compute hidden layer - units' error
  140. for(int j=0;j<hlayer.length;++j) hlayer[j].computeError(outputUnit.w[j], outputUnit.error);
  141.  
  142. //fix errors
  143. outputUnit.fixError();
  144. for(int j=0;j<hlayer.length;++j) hlayer[j].fixError();
  145.  
  146. //check if output units result matches the current example
  147. int o = outputUnit.output>0.5?1:0;
  148. if(o==example.output)count++;
  149. } //end-for each example
  150.  
  151. if(iteration>90000){
  152. //Probably not going to converge to optimal values
  153. System.out.println("\nIterations > 90000, stop...");
  154. displayOutputs();
  155. break;
  156. }
  157.  
  158. //if it matches all examples, stop training
  159. if(count==examples.length){
  160. System.out.println("\nTraining complete. No of iterations = "+iteration);
  161. displayOutputs();
  162. break;
  163. }
  164. } //end for-each
  165. }//end for-each re-initialization
  166. }//end main
  167.  
  168. //Apply the learned weights to all the examples
  169. public static void displayOutputs(){
  170. System.out.println("Displaying outputs for all examples... ");
  171. double[] outputsHLayer = new double[hlayer.length];
  172.  
  173. for(int e=0;e<examples.length;++e){
  174. for(int j=0;j<hlayer.length;++j) {
  175. outputsHLayer[j] = hlayer[j].computeOutput(examples[e].input);
  176. }
  177. outputUnit.computeOutput(outputsHLayer);
  178. System.out.println(outputUnit.output);
  179. }
  180. }
  181. }
Time limit exceeded #stdin #stdout 5s 321152KB
stdin
Standard input is empty
stdout
Iterations > 90000, stop...
Displaying outputs for all examples... 
0.018861254512881773
0.7270271284494716
0.5007550527204925
0.5024353957353963

Training complete. No of iterations = 45076
Displaying outputs for all examples... 
0.3944511789979849
0.5033004761575361
0.5008283246200929
0.2865272493546562

Training complete. No of iterations = 39707
Displaying outputs for all examples... 
0.39455754434259843
0.5008762488126696
0.5029579167912538
0.28715696580224176

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.43116164638530535
0.32096730276984053
0.9758219334403757
0.32228953888593287

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6166839799921503
0.5666948533039026
0.7323536028205314
0.03167884683374702

Training complete. No of iterations = 40841
Displaying outputs for all examples... 
0.39452717661115266
0.503094090661865
0.5008269212481119
0.28697118322186377

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6166211878884533
0.5668422625954125
0.7319327530115307
0.03077740429020573

Training complete. No of iterations = 41710
Displaying outputs for all examples... 
0.39443922904355433
0.5008752332069658
0.5032812980593797
0.2864679610647368

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.43118509762935275
0.32094810965348286
0.9764730159008141
0.3221952972850956

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6167919275615609
0.7330874634896468
0.566481324356442
0.03290715947755907

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6166772062595632
0.732246441498517
0.5667093511399082
0.0316086994910546

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6165821595267033
0.5669419656632928
0.7317498131260377
0.030098435644209765

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6173872457250869
0.7385064554164004
0.5647858195218177
0.042152846626878106

Training complete. No of iterations = 37937
Displaying outputs for all examples... 
0.3948110109910797
0.5023416631102245
0.5008218796461995
0.288600901324653

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6214283575714268
0.828001424205262
0.5472257959732171
0.09039462584473663

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6168468487773882
0.7333753534654978
0.5662073498563921
0.03471157734417047

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.4311238629451931
0.9766748031029869
0.3210037661745261
0.3223167655535901

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6177011028040171
0.7434187753032622
0.5636984082366729
0.04707421008440403

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6165997690220514
0.7316423493083851
0.5668272160520887
0.031058153880493645

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6167119890444502
0.5666504826787886
0.7325215262738196
0.031932607591875

Training complete. No of iterations = 39411
Displaying outputs for all examples... 
0.39455929549097385
0.5030074092853638
0.5008264331120219
0.2871578135013802

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.617076271997498
0.5657753887361301
0.7355093958724654
0.036942033416729254

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6169351769787368
0.734160375704149
0.566141018929304
0.03490487888174936

Training complete. No of iterations = 47156
Displaying outputs for all examples... 
0.3942723768016609
0.5008736806598922
0.5037466129073506
0.28548230033553007

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6170136141785073
0.7348556233882326
0.565964545100795
0.0358784043408415

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6172957219603963
0.7375916163519459
0.5651357193470172
0.04035062848635182

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.43111258425836774
0.321007458144265
0.9765543047063316
0.32233300776963775

Iterations > 90000, stop...
Displaying outputs for all examples... 
0.6166856317593773
0.5666261287609559
0.7321648472512584
0.03230124431402712