diff --git a/.gitignore b/.gitignore
index c9c8d50..835ea08 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,9 @@
-.idea
+.*
+*.iml
target
out
artifacts
nn_data.json
src/main/java/basicneuralnetwork/Test.java
src/main/java/basicneuralnetwork/activationfunctions/SoftmaxActivationFunction.java
-nn_data.json
\ No newline at end of file
+nn_data.json
diff --git a/pom.xml b/pom.xml
index ca6108f..d47f959 100644
--- a/pom.xml
+++ b/pom.xml
@@ -51,6 +51,12 @@
gson
2.8.4
+
+ junit
+ junit
+ 4.12
+ test
+
diff --git a/src/main/java/basicneuralnetwork/NeuralNetwork.java b/src/main/java/basicneuralnetwork/NeuralNetwork.java
index 4115b2e..8ae8865 100644
--- a/src/main/java/basicneuralnetwork/NeuralNetwork.java
+++ b/src/main/java/basicneuralnetwork/NeuralNetwork.java
@@ -13,8 +13,6 @@
*/
public class NeuralNetwork {
- private ActivationFunctionFactory activationFunctionFactory = new ActivationFunctionFactory();
-
private Random random = new Random();
// Dimensions of the neural network
@@ -114,7 +112,7 @@ public double[] guess(double[] input) {
throw new WrongDimensionException(input.length, inputNodes, "Input");
} else {
// Get ActivationFunction-object from the map by key
- ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey);
+ ActivationFunction activationFunction = ActivationFunctionFactory.createByName(activationFunctionKey);
// Transform array to matrix
SimpleMatrix output = MatrixUtilities.arrayToMatrix(input);
@@ -134,7 +132,7 @@ public void train(double[] inputArray, double[] targetArray) {
throw new WrongDimensionException(targetArray.length, outputNodes, "Output");
} else {
// Get ActivationFunction-object from the map by key
- ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey);
+ ActivationFunction activationFunction = ActivationFunctionFactory.createByName(activationFunctionKey);
// Transform 2D array to matrix
SimpleMatrix input = MatrixUtilities.arrayToMatrix(inputArray);
@@ -270,10 +268,6 @@ public void setActivationFunction(String activationFunction) {
this.activationFunctionKey = activationFunction;
}
- public void addActivationFunction(String key, ActivationFunction activationFunction){
- activationFunctionFactory.addActivationFunction(key, activationFunction);
- }
-
public double getLearningRate() {
return learningRate;
}
diff --git a/src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java b/src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java
new file mode 100644
index 0000000..d1f202b
--- /dev/null
+++ b/src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java
@@ -0,0 +1,58 @@
+package basicneuralnetwork;
+
+/**
+ * Created by MichalWa on 08.06.18
+ */
+public class NeuralNetworkBuilder {
+
+ private int inputNodes = 0;
+ private int hiddenLayers = 0;
+ private int hiddenNodes = 0;
+ private int outputNodes = 0;
+ private String activationFunction = null;
+ private double learningRate = -1.0;
+
+ public NeuralNetworkBuilder setInputNodes(int inputNodes) {
+ this.inputNodes = inputNodes;
+ return this;
+ }
+
+ public NeuralNetworkBuilder setHiddenLayers(int hiddenLayers) {
+ this.hiddenLayers = hiddenLayers;
+ return this;
+ }
+
+ public NeuralNetworkBuilder setHiddenNodes(int hiddenNodes) {
+ this.hiddenNodes = hiddenNodes;
+ return this;
+ }
+
+ public NeuralNetworkBuilder setOutputNodes(int outputNodes) {
+ this.outputNodes = outputNodes;
+ return this;
+ }
+
+ public NeuralNetworkBuilder setActivationFunction(String activationFunction) {
+ this.activationFunction = activationFunction;
+ return this;
+ }
+
+ public NeuralNetworkBuilder setLearningRate(double learningRate) {
+ this.learningRate = learningRate;
+ return this;
+ }
+
+ public NeuralNetwork create() {
+ if(inputNodes < 1) throw new IllegalStateException("There must be 1 or more input nodes.");
+ if(hiddenNodes < 1) throw new IllegalStateException("There must be 1 or more hidden nodes.");
+ if(outputNodes < 1) throw new IllegalStateException("There must be 1 or more output nodes");
+
+ NeuralNetwork nn = new NeuralNetwork(inputNodes, hiddenLayers, hiddenNodes, outputNodes);
+
+ if(activationFunction != null) nn.setActivationFunction(activationFunction);
+ if(learningRate != -1.0) nn.setLearningRate(learningRate);
+
+ return nn;
+ }
+
+}
diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java
index 2097a05..95e5e76 100644
--- a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java
+++ b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java
@@ -5,19 +5,38 @@
/**
* Created by KimFeichtinger on 20.04.18.
*/
-// This interface and it's methods have to be implemented in all ActivationFunction-classes
-public interface ActivationFunction {
-
- String SIGMOID = "SIGMOID";
- String TANH = "TANH";
- String RELU = "RELU";
+public abstract class ActivationFunction {
// Activation function
- SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input);
+ public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
+ SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
+ for (int i = 0; i < input.numRows(); i++) {
+ for (int j = 0; j < input.numCols(); j ++) {
+ double value = input.get(i, j);
+ output.set(i, j, apply(value));
+ }
+ }
+ return output;
+ }
// Derivative of activation function (not real derivative because Activation function has already been applied to the input)
- SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input);
-
- String getName();
-
+ public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
+ SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
+ for (int i = 0; i < input.numRows(); i++) {
+ for (int j = 0; j < input.numCols(); j ++) {
+ double value = input.get(i, j);
+ output.set(i, j, applyDerivative(value));
+ }
+ }
+ return output;
+ }
+
+ /** Applies the function to a single value */
+ protected abstract double apply(double value);
+
+ /** Applies the pseudo-derivative of the function to a single value */
+ protected abstract double applyDerivative(double value);
+
+ /** Returns the name of the function */
+ public abstract String getName();
}
diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java
index 356ff7c..d9b221c 100644
--- a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java
+++ b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java
@@ -2,31 +2,21 @@
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
+import java.util.function.Supplier;
/**
* Created by KimFeichtinger on 04.05.18.
*/
public class ActivationFunctionFactory {
- private Map activationFunctionMap = new HashMap<>();
-
- public ActivationFunctionFactory () {
- // Fill map with all the activation functions
- ActivationFunction sigmoid = new SigmoidActivationFunction();
- activationFunctionMap.put(sigmoid.getName(), sigmoid);
-
- ActivationFunction tanh = new TanhActivationFunction();
- activationFunctionMap.put(tanh.getName(), tanh);
-
- ActivationFunction relu = new ReLuActivationFunction();
- activationFunctionMap.put(relu.getName(), relu);
- }
-
- public ActivationFunction getActivationFunctionByKey (String activationFunctionKey) {
- return activationFunctionMap.get(activationFunctionKey);
+ private static Map> factories = new HashMap<>();
+
+ public static ActivationFunction createByName(String name) {
+ return Optional.ofNullable(factories.get(name)).map(Supplier::get).orElse(null);
}
- public void addActivationFunction(String key, ActivationFunction activationFunction) {
- activationFunctionMap.put(key, activationFunction);
+ public static void register(String key, Supplier factory) {
+ factories.put(key, factory);
}
}
diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java
index 5cb7709..70ca25c 100644
--- a/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java
+++ b/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java
@@ -1,46 +1,26 @@
package basicneuralnetwork.activationfunctions;
-import org.ejml.simple.SimpleMatrix;
-
/**
* Created by KimFeichtinger on 26.04.18.
*/
-public class ReLuActivationFunction implements ActivationFunction {
-
- private static final String NAME = "RELU";
-
- public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
- SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
+public class ReLuActivationFunction extends ActivationFunction {
- for (int i = 0; i < input.numRows(); i++) {
- // Column is always 0 because input has only one column
- double value = input.get(i, 0);
- double result = value > 0 ? value : 0;
-
- output.set(i, 0, result);
- }
-
- // Formula:
- // for input < 0: 0, else input
- return output;
+ public static final String NAME = "relu";
+
+ static {
+ ActivationFunctionFactory.register(NAME, ReLuActivationFunction::new);
}
-
- public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
- SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
-
- for (int i = 0; i < input.numRows(); i++) {
- // Column is always 0 because input has only one column
- double value = input.get(i, 0);
- double result = value > 0 ? 1 : 0;
-
- output.set(i, 0, result);
- }
-
- // Formula:
- // for input > 0: 1, else 0
- return output;
+
+ @Override
+ protected double apply(double value) {
+ return value > 0 ? value : 0;
}
-
+
+ @Override
+ protected double applyDerivative(double value) {
+ return value > 0 ? 1 : 0;
+ }
+
public String getName() {
return NAME;
}
diff --git a/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java
index 839c11e..6da8d92 100644
--- a/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java
+++ b/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java
@@ -1,48 +1,26 @@
package basicneuralnetwork.activationfunctions;
-import org.ejml.simple.SimpleMatrix;
-
/**
* Created by KimFeichtinger on 20.04.18.
*/
-public class SigmoidActivationFunction implements ActivationFunction {
-
- private static final String NAME = "SIGMOID";
-
- // Sigmoid
- public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
- SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
+public class SigmoidActivationFunction extends ActivationFunction {
- for (int i = 0; i < input.numRows(); i++) {
- // Column is always 0 because input has only one column
- double value = input.get(i, 0);
- double result = 1 / (1 + Math.exp(-value));
-
- output.set(i, 0, result);
- }
-
- // Formula:
- // 1 / (1 + Math.exp(-input));
- return output;
+ public static final String NAME = "sigmoid";
+
+ static {
+ ActivationFunctionFactory.register(NAME, SigmoidActivationFunction::new);
}
-
- // Derivative of Sigmoid (not real derivative because Activation function has already been applied to the input)
- public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
- SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
-
- for (int i = 0; i < input.numRows(); i++) {
- // Column is always 0 because input has only one column
- double value = input.get(i, 0);
- double result = value * (1 - value);
-
- output.set(i, 0, result);
- }
-
- // Formula:
- // input * (1 - input);
- return output;
+
+ @Override
+ protected double apply(double value) {
+ return 1 / (1 + Math.exp(-value));
}
-
+
+ @Override
+ protected double applyDerivative(double value) {
+ return value * (1 - value);
+ }
+
public String getName() {
return NAME;
}
diff --git a/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java
index b4e4c4d..2d87384 100644
--- a/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java
+++ b/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java
@@ -1,47 +1,26 @@
package basicneuralnetwork.activationfunctions;
-import org.ejml.simple.SimpleMatrix;
-
/**
* Created by KimFeichtinger on 20.04.18.
*/
-public class TanhActivationFunction implements ActivationFunction {
-
- private static final String NAME = "TANH";
-
- public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
- SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
-
- for (int i = 0; i < input.numRows(); i++) {
- // Column is always 0 because input has only one column
- double value = input.get(i, 0);
- double result = Math.tanh(value);
-
- output.set(i, 0, result);
- }
-
- // Formula:
- // 2 * (1 / (1 + Math.exp(2 * -input))) - 1;
- // Math.tanh(input);
- return output;
+public class TanhActivationFunction extends ActivationFunction {
+
+ public static final String NAME = "tanh";
+
+ static {
+ ActivationFunctionFactory.register(NAME, TanhActivationFunction::new);
}
-
- public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
- SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
-
- for (int i = 0; i < input.numRows(); i++) {
- // Column is always 0 because input has only one column
- double value = input.get(i, 0);
- double result = 1 - (value * value);
-
- output.set(i, 0, result);
- }
-
- // Formula:
- // 1 - (input * input);
- return output;
+
+ @Override
+ protected double apply(double value) {
+ return Math.tanh(value);
}
-
+
+ @Override
+ protected double applyDerivative(double value) {
+ return 1 - (value * value);
+ }
+
public String getName() {
return NAME;
}
diff --git a/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java b/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java
index f3093f1..0d5e941 100644
--- a/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java
+++ b/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java
@@ -17,11 +17,12 @@
*/
public class FileReaderAndWriter {
- public static void writeToFile(NeuralNetwork nn){
+ private static final Gson GSON = getGsonBuilder().create();
+
+ public static void writeToFile(NeuralNetwork nn) {
try {
FileWriter file = new FileWriter("nn_data.json");
- Gson gson = getGsonBuilder().create();
- String nnData = gson.toJson(nn);
+ String nnData = GSON.toJson(nn);
file.write(nnData);
file.flush();
@@ -34,9 +35,8 @@ public static NeuralNetwork readFromFile() {
NeuralNetwork nn = null;
try {
- Gson gson = getGsonBuilder().create();
JsonReader jsonReader = new JsonReader(new FileReader("nn_data.json"));
- nn = gson.fromJson(jsonReader, NeuralNetwork.class);
+ nn = GSON.fromJson(jsonReader, NeuralNetwork.class);
} catch (IOException e) {
e.printStackTrace();
}
@@ -45,7 +45,7 @@ public static NeuralNetwork readFromFile() {
}
// Get a GsonBuilder-object with all the needed TypeAdapters added
- private static GsonBuilder getGsonBuilder(){
+ private static GsonBuilder getGsonBuilder() {
GsonBuilder gsonBuilder = new GsonBuilder();
gsonBuilder.registerTypeAdapter(ActivationFunction.class, new InterfaceAdapter());
diff --git a/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java b/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java
index 9b56441..44684c0 100644
--- a/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java
+++ b/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java
@@ -10,10 +10,11 @@
import java.lang.reflect.Type;
-// This class is needed to make the interfaces/ abstract classes used in this project serializable/ deserializable
-// so that they can be converted to JSON or from JSON by Google Gson-library
-// The solution was found here:
-// https://stackoverflow.com/questions/4795349/how-to-serialize-a-class-with-an-interface/9550086#9550086
+/** This class is needed to make the interfaces/ abstract classes used in this project serializable/ deserializable
+ * so that they can be converted to JSON or from JSON by Google Gson-library
+ *
+ * The solution was found here:
+ * (link) */
public class InterfaceAdapter implements JsonSerializer, JsonDeserializer {
@Override
diff --git a/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java b/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java
index bd71708..fb26545 100644
--- a/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java
+++ b/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java
@@ -10,13 +10,13 @@
*/
public class MatrixUtilities {
- // Converts a 2D array into a SimpleMatrix
+ /** Converts a 2D array into a SimpleMatrix */
public static SimpleMatrix arrayToMatrix(double[] i) {
double[][] input = {i};
return new SimpleMatrix(input).transpose();
}
- // Converts a SimpleMatrix into a 2D array
+ /** Converts a SimpleMatrix into a 2D array */
public static double[][] matrixTo2DArray(SimpleMatrix i) {
double[][] result = new double[i.numRows()][i.numCols()];
@@ -28,7 +28,7 @@ public static double[][] matrixTo2DArray(SimpleMatrix i) {
return result;
}
- // Returns one specific column of a matrix as a 1D array
+ /** Returns one specific column of a matrix as a 1D array */
public static double[] getColumnFromMatrixAsArray(SimpleMatrix data, int column) {
double[] result = new double[data.numRows()];
@@ -39,7 +39,7 @@ public static double[] getColumnFromMatrixAsArray(SimpleMatrix data, int column)
return result;
}
- // Merge two matrices and return a new one
+ /** Merge two matrices and return a new one */
public static SimpleMatrix mergeMatrices(SimpleMatrix matrixA, SimpleMatrix matrixB, double probability) {
if (matrixA.numCols() != matrixB.numCols() || matrixA.numRows() != matrixB.numRows()) {
throw new WrongDimensionException();
diff --git a/src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java b/src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java
new file mode 100644
index 0000000..1a95741
--- /dev/null
+++ b/src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java
@@ -0,0 +1,36 @@
+package basicneuralnetwork;
+
+import basicneuralnetwork.activationfunctions.SigmoidActivationFunction;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+public class NeuralNetworkBuilderTest {
+
+ int inputNodes = 2;
+ int hiddenLayers = 3;
+ int hiddenNodes = 4;
+ int outputNodes = 5;
+ String activationFunction = SigmoidActivationFunction.NAME;
+ double learningRate = 0.5;
+
+ @Test
+ public void builderTest() {
+ NeuralNetwork nn = new NeuralNetworkBuilder()
+ .setInputNodes(inputNodes)
+ .setHiddenLayers(hiddenLayers)
+ .setHiddenNodes(hiddenNodes)
+ .setOutputNodes(outputNodes)
+ .setActivationFunction(activationFunction)
+ .setLearningRate(learningRate)
+ .create();
+
+ assertEquals(inputNodes, nn.getInputNodes());
+ assertEquals(hiddenLayers, nn.getHiddenLayers());
+ assertEquals(hiddenNodes, nn.getHiddenNodes());
+ assertEquals(outputNodes, nn.getOutputNodes());
+ assertEquals(nn.getActivationFunctionName(), activationFunction);
+ assertEquals(nn.getLearningRate(), learningRate, 0.0);
+ }
+
+}
\ No newline at end of file