import java.util.Random;
/**
 * One labelled training sample: an input vector together with its target output.
 */
class Example {
    /** Input values of the sample; the constructor stores the caller's array without copying it. */
    double[] input;

    /** Target value that belongs to {@link #input}. */
    double output;

    /**
     * Builds a sample from an input vector and its target value.
     *
     * @param in  input values; the reference is retained as-is
     * @param out target output for these inputs
     */
    Example(double[] in, double out) {
        input = in;
        output = out;
    }
}
/**
 * A single sigmoid unit of the hidden layer.
 *
 * <p>The unit computes {@code sigmoid(bias + sum(w[i] * input[i]))}. Its error term is derived from
 * the error of the downstream output unit, and {@link #fixError()} moves the weights along that
 * error. The {@code bias} is a constant offset: {@link #fixError()} adjusts only {@code w}.
 */
class HiddenSigmoidUnit {
    /** Random source for the initial weights. */
    private static final Random rand = new Random();

    /** Activation produced by the most recent {@link #computeOutput(double[])} call. */
    double output;
    /** Error term produced by the most recent {@link #computeError(double, double)} call. */
    double error;
    /** Input vector of the most recent forward pass (reference retained, not copied). */
    double input[];
    /** One weight per input. */
    double w[];
    /** Learning rate applied in {@link #fixError()}. */
    double n = 0.05;
    /** Constant offset added to the weighted sum; it is not changed by training. */
    double bias = 1;

    /** Creates a unit with two inputs and random initial weights. */
    HiddenSigmoidUnit() {
        this(2);
    }

    /**
     * Creates a unit with the given number of inputs. Each weight is drawn independently from the
     * open interval (-1, 1).
     *
     * @param nInputs number of inputs the unit accepts
     */
    HiddenSigmoidUnit(int nInputs) {
        w = new double[nInputs];
        for (int i = 0; i < w.length; ++i) {
            // The sign is drawn per weight so the weights of one unit are not forced to share it.
            int sign = rand.nextBoolean() ? 1 : -1;
            w[i] = sign * rand.nextDouble();
        }
    }

    /**
     * Runs the forward pass for this unit and remembers the input for a later {@link #fixError()}.
     *
     * @param input input values, one per weight
     * @return the sigmoid activation, also stored in {@link #output}
     * @throws IllegalArgumentException if the number of inputs differs from the number of weights
     */
    public double computeOutput(double[] input) {
        if (input.length != w.length) {
            throw new IllegalArgumentException(
                "expected " + w.length + " inputs but got " + input.length);
        }
        this.input = input;
        double net = bias;
        for (int i = 0; i < w.length; ++i) {
            net += w[i] * input[i];
        }
        output = 1 / (1 + Math.exp(-net));
        return output;
    }

    /**
     * Computes this unit's error term from the downstream output unit.
     *
     * @param w               weight on the connection from this unit to the output unit
     * @param outputUnitError error term of the output unit
     * @return {@code output * (1 - output) * outputUnitError * w}, also stored in {@link #error}
     */
    public double computeError(double w, double outputUnitError) {
        error = output * (1 - output) * (outputUnitError * w);
        return error;
    }

    /**
     * Adds {@code n * error * input[i]} to every weight, using the input of the last forward pass.
     *
     * @throws IllegalStateException if {@link #computeOutput(double[])} has not been called yet
     */
    public void fixError() {
        if (input == null) {
            throw new IllegalStateException("computeOutput must be called before fixError");
        }
        for (int i = 0; i < w.length; ++i) {
            w[i] += n * error * input[i];
        }
    }
}
/**
 * The single sigmoid output unit of the network.
 *
 * <p>The unit computes {@code sigmoid(bias + sum(w[i] * input[i]))} over the hidden-layer outputs.
 * The {@code bias} is a constant offset: {@link #fixError()} adjusts only {@code w}.
 */
class OutputUnit {
    /** Random source for the initial weights. */
    private static final Random rand = new Random();

    /** Target value passed to the most recent {@link #computeError(double)} call. */
    double t;
    /** Activation produced by the most recent {@link #computeOutput(double[])} call. */
    double output;
    /** Error term produced by the most recent {@link #computeError(double)} call. */
    double error;
    /** Input vector of the most recent forward pass (reference retained, not copied). */
    double input[];
    /** One weight per hidden unit. */
    double w[];
    /** Constant offset added to the weighted sum; it is not changed by training. */
    double bias = 1;
    /** Learning rate applied in {@link #fixError()}. */
    double n = 0.05;

    /**
     * Creates an output unit fed by {@code hUnits} hidden units. Each weight is drawn independently
     * from the open interval (-1, 1).
     *
     * @param hUnits number of hidden units feeding this unit
     */
    OutputUnit(int hUnits) {
        w = new double[hUnits];
        for (int i = 0; i < w.length; ++i) {
            // The sign is drawn per weight so the weights are not all forced to share it.
            int sign = rand.nextBoolean() ? 1 : -1;
            w[i] = sign * rand.nextDouble();
        }
    }

    /**
     * Runs the forward pass over all hidden-layer outputs and stores the result in {@link #output}.
     * The input is remembered for a later {@link #fixError()}.
     *
     * @param input hidden-layer outputs, one per weight
     * @throws IllegalArgumentException if the number of inputs differs from the number of weights
     */
    public void computeOutput(double[] input) {
        if (input.length != w.length) {
            throw new IllegalArgumentException(
                "expected " + w.length + " inputs but got " + input.length);
        }
        this.input = input;
        double net = bias;
        for (int i = 0; i < w.length; ++i) {
            net += input[i] * w[i];
        }
        output = 1 / (1 + Math.exp(-net));
    }

    /**
     * Stores {@code output * (1 - output) * (t - output)} in {@link #error}.
     *
     * @param t target value for the example that produced the current {@link #output}
     */
    public void computeError(double t) {
        this.t = t;
        error = output * (1 - output) * (t - output);
    }

    /**
     * Adds {@code n * error * input[i]} to every weight, using the input of the last forward pass.
     *
     * @throws IllegalStateException if {@link #computeOutput(double[])} has not been called yet
     */
    public void fixError() {
        if (input == null) {
            throw new IllegalStateException("computeOutput must be called before fixError");
        }
        for (int i = 0; i < w.length; ++i) {
            w[i] += n * error * input[i];
        }
    }
}
/**
 * Trains a network of one hidden layer and one output unit on the XOR truth table, updating the
 * weights after every example. Training is repeated from fresh random weights several times, and
 * the outputs for all examples are printed at the end of each run.
 */
class Net {
    /** Number of independent training runs, each starting from fresh random weights. */
    private static final int RESTARTS = 100;
    /** A run is abandoned once its iteration counter exceeds this value. */
    private static final int MAX_ITERATIONS = 90000;
    /** Number of units in the hidden layer. */
    private static final int HIDDEN_UNITS = 2;

    static OutputUnit outputUnit;
    static HiddenSigmoidUnit[] hlayer;
    /** XOR truth table: two binary inputs and the expected output. */
    static Example[] examples = {
        new Example(new double[]{0, 0}, 0),
        new Example(new double[]{0, 1}, 1),
        new Example(new double[]{1, 0}, 1),
        new Example(new double[]{1, 1}, 0)
    };

    /**
     * Runs {@code RESTARTS} training runs. Each run ends either when every example was classified
     * correctly within one pass or when the iteration limit is exceeded; both cases print a message
     * followed by the network's outputs.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        hlayer = new HiddenSigmoidUnit[HIDDEN_UNITS];
        for (int check = 0; check < RESTARTS; ++check) {
            initializeNetwork();
            for (int iteration = 0; ; ++iteration) {
                // Examples are visited in their declared order; no shuffling is performed.
                int count = trainOnePass();
                if (iteration > MAX_ITERATIONS) {
                    // Probably not going to converge to optimal values
                    System.out.println("\nIterations > " + MAX_ITERATIONS + ", stop...");
                    displayOutputs();
                    break;
                }
                // All examples matched during this pass: stop training.
                if (count == examples.length) {
                    System.out.println("\nTraining complete. No of iterations = " + iteration);
                    displayOutputs();
                    break;
                }
            }
        }
    }

    /** Replaces every hidden unit and the output unit with freshly initialised ones. */
    private static void initializeNetwork() {
        for (int i = 0; i < hlayer.length; ++i) {
            hlayer[i] = new HiddenSigmoidUnit();
        }
        outputUnit = new OutputUnit(hlayer.length);
    }

    /**
     * Presents every example once, updating the weights after each one.
     *
     * @return how many examples the network classified correctly (output thresholded at 0.5),
     *         measured on the forward pass made before that example's weight update
     */
    private static int trainOnePass() {
        int count = 0;
        for (Example example : examples) {
            forward(example.input);
            // The hidden errors must be computed before the output weights are changed,
            // because they read the output unit's current weights.
            outputUnit.computeError(example.output);
            for (int j = 0; j < hlayer.length; ++j) {
                hlayer[j].computeError(outputUnit.w[j], outputUnit.error);
            }
            outputUnit.fixError();
            for (HiddenSigmoidUnit unit : hlayer) {
                unit.fixError();
            }
            int o = outputUnit.output > 0.5 ? 1 : 0;
            if (o == example.output) {
                count++;
            }
        }
        return count;
    }

    /** Feeds the input through every hidden unit and the hidden outputs through the output unit. */
    private static void forward(double[] input) {
        double[] outputsHLayer = new double[hlayer.length];
        for (int j = 0; j < hlayer.length; ++j) {
            outputsHLayer[j] = hlayer[j].computeOutput(input);
        }
        outputUnit.computeOutput(outputsHLayer);
    }

    /** Applies the current weights to all the examples and prints the raw output for each. */
    public static void displayOutputs() {
        System.out.println("Displaying outputs for all examples... ");
        for (Example example : examples) {
            forward(example.input);
            System.out.println(outputUnit.output);
        }
    }
}