Understanding and Implementing Neural Networks in Java from Scratch 💻
Learning the popular concept in a 💪 strongly typed language
A simple Neural Network class written completely in Java from scratch, without any external libraries.
A Neural Network is a computational system loosely modeled on the human brain. Its neurons are connected by synapses, concepts borrowed from the 🧠. The system simulates the brain's learning process by adjusting the weights of the synapses through backpropagation.
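Backpropagation is gradient descent on the network's error, with the chain rule carrying the gradient back through each layer. The smallest case is a single sigmoid neuron, where there is only one layer to go back through and the update reduces to the delta rule. The sketch below trains such a neuron on logical OR; the class `SingleNeuron` and its methods are hypothetical, chosen for illustration, and are not part of the NeuralNetwork class in this repository.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: one sigmoid neuron learning logical OR by gradient descent.
 * Not the implementation used by NeuralNetwork.java.
 */
public class SingleNeuron {
    // Weights of the two input "synapses" plus a bias; all start at 0
    double w1, w2, b;

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    double predict(double x1, double x2) {
        return sigmoid(w1 * x1 + w2 * x2 + b);
    }

    // One pass of online (per-sample) gradient descent on E = 0.5 * (y - t)^2
    void trainEpoch(double[][] X, double[] T, double lr) {
        for (int i = 0; i < X.length; i++) {
            double y = predict(X[i][0], X[i][1]);
            // Chain rule: dE/dz = (y - t) * sigmoid'(z), where sigmoid'(z) = y * (1 - y)
            double delta = (y - T[i]) * y * (1 - y);
            // dz/dw = input value, dz/db = 1
            w1 -= lr * delta * X[i][0];
            w2 -= lr * delta * X[i][1];
            b -= lr * delta;
        }
    }

    double meanSquaredError(double[][] X, double[] T) {
        double sum = 0;
        for (int i = 0; i < X.length; i++) {
            double e = predict(X[i][0], X[i][1]) - T[i];
            sum += e * e;
        }
        return sum / X.length;
    }

    public static void main(String[] args) {
        double[][] X = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[] T = { 0, 1, 1, 1 }; // logical OR
        SingleNeuron n = new SingleNeuron();
        System.out.println("MSE before: " + n.meanSquaredError(X, T));
        for (int epoch = 0; epoch < 10000; epoch++) {
            n.trainEpoch(X, T, 0.5);
        }
        System.out.println("MSE after:  " + n.meanSquaredError(X, T));
        for (double[] x : X) {
            System.out.println(Arrays.toString(x) + " -> " + n.predict(x[0], x[1]));
        }
    }
}
```

Running `main` prints the mean squared error before and after training; with zero initial weights every output starts at 0.5, so the error starts at 0.25 and should fall as the outputs move toward their targets. A full network repeats the same chain-rule step layer by layer, from the output back toward the inputs.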
Here's why you should consider using it:
- Easy to use and test
- Logging features
- Customizable
I also developed an even more feature-rich JavaScript library, which can be found here: Lite Neural Network
- Place `NeuralNetwork.java` in your project directory
- Instantiate a `NeuralNetwork` object
- Prepare the data as 2D `double` arrays (one row per sample)
- Train the model with `fit`
- Test the model with `predict`
For a full explanation, please refer to the Article.
```java
import java.util.List;

public class Main {
    public static void main(String[] args) {
        // Sample network: 2 input features, 10 hidden nodes and 1 output
        // Default constructor
        NeuralNetwork nn = new NeuralNetwork(2, 10, 1);
        // Constructor with a custom learning rate
        NeuralNetwork nnCustomLr = new NeuralNetwork(2, 10, 1, 0.01);
        // Constructor with multi-threading enabled
        NeuralNetwork nnWithMultithreading = new NeuralNetwork(2, 10, 1, true);
        // Constructor with multi-threading enabled and a custom learning rate
        NeuralNetwork nnCustomLrWithMultithreading = new NeuralNetwork(2, 10, 1, 0.01, true);

        // Training data for XOR: each row of X holds the 2 input features,
        // and the matching row of Y holds the 1 expected output
        double[][] X = { { 0, 0 }, { 1, 0 }, { 0, 1 }, { 1, 1 } };
        double[][] Y = { { 0 }, { 1 }, { 1 }, { 0 } };

        // Train for 500 epochs; pick one of the variants below
        nn.fit(X, Y, 500);       // silent training
        // nn.fit(X, Y, 500, 0); // logging level 0: shows training time and average error
        // nn.fit(X, Y, 500, 1); // logging level 1: shows logs for each epoch

        // Test the model on 4 data points, passing each one to predict
        double[][] input = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        for (double[] d : input) {
            List<Double> output = nn.predict(d);
            System.out.println(output);
        }
    }
}
```
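To make the `(2, 10, 1)` constructor arguments concrete, here is a minimal sketch of a feedforward pass through a 2-10-1 network: every hidden node takes a weighted sum of the 2 inputs plus a bias and applies an activation, and the output node does the same over the 10 hidden values. The sigmoid activation, the weight layout and the names `ForwardPass`, `layer` and `randomMatrix` are assumptions for illustration; the actual internals of `NeuralNetwork.java` may differ.

```java
import java.util.Random;

/**
 * Hypothetical sketch of a forward pass through a 2-10-1 network with sigmoid activations.
 * Weight layout and activation are assumptions, not taken from NeuralNetwork.java.
 */
public class ForwardPass {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One layer: out[j] = sigmoid(bias[j] + sum_i w[j][i] * in[i])
    static double[] layer(double[][] w, double[] bias, double[] in) {
        double[] out = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            double z = bias[j];
            for (int i = 0; i < in.length; i++) {
                z += w[j][i] * in[i];
            }
            out[j] = sigmoid(z);
        }
        return out;
    }

    // Weights drawn uniformly from [-1, 1)
    static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        for (double[] row : m) {
            for (int i = 0; i < cols; i++) {
                row[i] = rng.nextDouble() * 2 - 1;
            }
        }
        return m;
    }

    public static void main(String[] args) {
        int inputs = 2, hidden = 10, outputs = 1;
        Random rng = new Random(42);
        double[][] wInputHidden = randomMatrix(hidden, inputs, rng);   // 10 x 2
        double[] biasHidden = new double[hidden];
        double[][] wHiddenOutput = randomMatrix(outputs, hidden, rng); // 1 x 10
        double[] biasOutput = new double[outputs];

        double[] h = layer(wInputHidden, biasHidden, new double[] { 1, 0 });
        double[] y = layer(wHiddenOutput, biasOutput, h);
        System.out.println(y.length + " output(s), y[0] = " + y[0]);
    }
}
```

Training, as in `fit`, would then adjust `wInputHidden` and `wHiddenOutput` by propagating the output error backwards through these two layers.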
See the open issues for a list of proposed features (and known issues).
Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated.
- Fork the Project
- Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
- Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
- Push to the Branch (`git push origin feature/AmazingFeature`)
- Open a Pull Request
Roadmap:
- Basic Library
- Documentation
- Explanation Article
- Multi-Threading Support
- Interface to load data easily
- Multiple layer architecture
- Improve performance
- Unit Testing
- Production
Distributed under the MIT License. See LICENSE for more information.
Contact: @suyashysonawane
Project Link: https://github.com/suyashsonawane/JavaNet