Java nets
A first look at neural networks with Java
Disclaimer: This is a port of Francois Chollet’s notebooks to Java. All code is written by us; most of the text is taken from Francois’ original version.
In this chapter we’ll give you a first glimpse of what a neural network is and how to run one using Java. We use the BeakerX notebook extensions for Jupyter to run Java code interactively in notebook cells. Since you’re going to work with Deeplearning4j (DL4J) in this notebook, we first need to add a JAR with all the dependencies.
%classpath add jar ./pydl4j-1.0.0-SNAPSHOT-bin.jar
To check that the DL4J dependencies were installed correctly, let’s verify that we’re using DL4J version 1.0.0-SNAPSHOT.
import java.util.List;
import org.nd4j.versioncheck.*;
// Print the build version reported for the first ND4J/DL4J artifact found on the classpath
List<VersionInfo> versions = VersionCheck.getVersionInfos();
System.out.println(versions.get(0).getBuildVersion());
1.0.0-SNAPSHOT
The last step before we can really get started is to define a Java class with static fields that keeps track of all the state we use in this section. We do this so that the state persists across notebook cells. In particular, for a neural network we need:
- The neural network (model) that we will train.
- Two types of data: one to let the neural network learn (trainData) and one to evaluate the results (testData).
Don’t worry if that doesn’t make a lot of sense to you yet; we’ll explain all of this in more detail soon. First, let’s define the Demo class that will store the model and data for this session.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.lang.ClassLoader;
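// Holds state that must survive across notebook cells
public class Demo {
    public static MultiLayerNetwork model;    // the network we will train
    public static DataSetIterator trainData;  // MNIST training batches
    public static DataSetIterator testData;   // MNIST test batches
    public static ClassLoader loader = Nd4j.class.getClassLoader(); // ND4J's class loader, reused in later cells
}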
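DL4J comes with a built-in iterator for the MNIST dataset, MnistDataSetIterator, which downloads the data the first time it is used. We create two such iterators, one for the training set and one for the test set. Each iterator yields so-called DataSets; a DataSet is a handy abstraction for a batch of features paired with their labels.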
Your model will learn from trainData and then be tested on testData. The MNIST images are encoded as ND4J arrays in which each 28 × 28 image is flattened into a row of 784 pixel values, and the labels are one-hot encoded ND4J arrays with 10 columns, one for each digit from 0 to 9.
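To make this data layout concrete, here is a minimal sketch of the two encodings used by DL4J's MNIST iterator: each 28 × 28 image becomes a row of 784 values, and each label becomes a one-hot vector of 10 values with a 1 at the position of the digit. It uses plain double arrays instead of ND4J arrays, and the class and method names (MnistEncoding, flatten, oneHot) are hypothetical, chosen for illustration; this is not how MnistDataSetIterator is implemented internally.

```java
import java.util.Arrays;

// Hypothetical helper for illustration: plain double[] arrays stand in for ND4J arrays.
final class MnistEncoding {
    static final int ROWS = 28;
    static final int COLS = 28;
    static final int NUM_CLASSES = 10;

    // Flatten a 28x28 image row by row into a vector of 784 values.
    static double[] flatten(double[][] image) {
        double[] flat = new double[ROWS * COLS];
        for (int r = 0; r < ROWS; r++) {
            for (int c = 0; c < COLS; c++) {
                flat[r * COLS + c] = image[r][c];
            }
        }
        return flat;
    }

    // Turn a digit 0..9 into a one-hot vector with 1.0 at the digit's index.
    static double[] oneHot(int digit) {
        if (digit < 0 || digit >= NUM_CLASSES) {
            throw new IllegalArgumentException("digit must be in 0..9: " + digit);
        }
        double[] label = new double[NUM_CLASSES];
        label[digit] = 1.0;
        return label;
    }

    public static void main(String[] args) {
        double[][] image = new double[ROWS][COLS];
        image[1][2] = 0.5; // one non-zero pixel at row 1, column 2
        double[] flat = flatten(image);
        System.out.println(flat.length);         // 784
        System.out.println(flat[1 * COLS + 2]);  // 0.5
        System.out.println(Arrays.toString(oneHot(3)));
    }
}
```

Row-by-row flattening means pixel (r, c) ends up at index r * 28 + c, which is why a batch of 64 images has shape [64, 784].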
Let’s load the training and test data and look at the shape of one training batch:
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
int batchSize = 64; // number of examples per mini-batch
int rngSeed = 123;  // random number seed for reproducibility
Demo.trainData = new MnistDataSetIterator(batchSize, true, rngSeed);  // training set
Demo.testData = new MnistDataSetIterator(batchSize, false, rngSeed);  // test set
import org.nd4j.linalg.dataset.api.DataSet;
import java.util.Arrays;
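// Make ND4J's class loader the context class loader for this notebook cell
Thread.currentThread().setContextClassLoader(Demo.loader);
DataSet batch = Demo.trainData.next(); // take one mini-batch
System.out.println(Arrays.toString(batch.getFeatures().shape())); // [batch size, 28 * 28]
Demo.trainData.reset(); // rewind so that training later starts from the first batch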
[64, 784]
Our workflow will be as follows: first, we will present our neural network with the training data, trainData. The network will then learn to associate images and labels. Finally, we will ask the network to produce predictions for the images in testData, and we will verify whether these predictions match the labels from testData.
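Checking whether a prediction matches a label comes down to comparing two indices. Each row of the network’s output holds 10 probabilities, and the predicted digit is the index of the largest one (the argmax); with one-hot labels, the true digit is the index of the single 1. The sketch below shows that comparison on plain double arrays instead of ND4J arrays; the class and method names are hypothetical, and in this notebook DL4J’s Evaluation class does this bookkeeping for us.

```java
// Hypothetical helper: decide whether one prediction matches its one-hot label.
final class PredictionCheck {
    // Index of the largest value; ties resolve to the lowest index.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // A prediction is correct when its most probable class equals the labelled class.
    static boolean isCorrect(double[] probabilities, double[] oneHotLabel) {
        return argMax(probabilities) == argMax(oneHotLabel);
    }

    public static void main(String[] args) {
        double[] probs = {0.01, 0.02, 0.05, 0.70, 0.02, 0.05, 0.05, 0.04, 0.03, 0.03};
        double[] label = {0, 0, 0, 1, 0, 0, 0, 0, 0, 0};
        System.out.println(argMax(probs));           // 3
        System.out.println(isCorrect(probs, label)); // true
    }
}
```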
Let’s build our network. Again, remember that you aren’t supposed to understand everything about this example just yet.
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
Thread.currentThread().setContextClassLoader(Demo.loader);
final int numRows = 28;
final int numColumns = 28;
int outputNum = 10; // number of output classes
int batchSize = 64; // number of examples per mini-batch (not used in this cell)
int rngSeed = 123; // random number seed for reproducibility
double rate = 0.0015; // learning rate
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(rngSeed)
    .activation(Activation.RELU)         // default activation (the output layer overrides it)
    .weightInit(WeightInit.XAVIER)
    .updater(new Nesterovs(rate, 0.98))  // optimizer: learning rate and Nesterov momentum 0.98
    .l2(rate * 0.005)                    // L2 weight regularization
    .list()
    .layer(new DenseLayer.Builder().nIn(numRows * numColumns).nOut(500).build())  // 784 -> 500
    .layer(new DenseLayer.Builder().nIn(500).nOut(100).build())                   // 500 -> 100
    .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)            // loss function
        .activation(Activation.SOFTMAX).nIn(100).nOut(outputNum).build())         // 100 -> 10 probabilities
    .pretrain(false).backprop(true) // use backpropagation to adjust weights
    .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100)); // print the score every 100 iterations
Demo.model = model;
The core building block of neural networks is the “layer”, a data-processing module which you can conceive of as a “filter” for data: data goes in and comes out in a more useful form. More precisely, layers extract representations out of the data fed into them, hopefully representations that are more meaningful for the problem at hand. Most of deep learning really consists of chaining together simple layers that implement a form of progressive “data distillation”. A deep learning model is like a sieve for data processing, made of a succession of increasingly refined data filters: the “layers”.
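As a rough illustration of a layer acting as a data filter, the sketch below implements a fully connected layer in plain Java: for each output unit it computes a weighted sum of the inputs plus a bias and passes the result through ReLU, max(0, x), the default activation set with .activation(Activation.RELU) in the configuration above. Chaining two such layers shows each one transforming the output of the previous one. The weights are tiny hand-picked values for illustration, not trained parameters, and the class name DenseSketch is hypothetical.

```java
import java.util.Arrays;

// Hypothetical minimal fully connected ("dense") layer with ReLU activation.
final class DenseSketch {
    final double[][] weights; // weights[out][in]
    final double[] bias;      // one bias per output unit

    DenseSketch(double[][] weights, double[] bias) {
        this.weights = weights;
        this.bias = bias;
    }

    // output[j] = relu(sum_i weights[j][i] * input[i] + bias[j])
    double[] forward(double[] input) {
        double[] out = new double[bias.length];
        for (int j = 0; j < bias.length; j++) {
            double sum = bias[j];
            for (int i = 0; i < input.length; i++) {
                sum += weights[j][i] * input[i];
            }
            out[j] = Math.max(0.0, sum); // ReLU
        }
        return out;
    }

    public static void main(String[] args) {
        DenseSketch first = new DenseSketch(
                new double[][] {{1.0, -1.0, 0.5}, {-0.5, 0.5, 1.0}},
                new double[] {0.0, 0.1});
        DenseSketch second = new DenseSketch(
                new double[][] {{2.0, 1.0}},
                new double[] {-0.5});
        double[] x = {1.0, 2.0, 3.0};
        double[] h = first.forward(x);  // first "filter": 3 inputs -> 2 features
        double[] y = second.forward(h); // second filter refines the first one's output
        System.out.println(Arrays.toString(h));
        System.out.println(Arrays.toString(y));
    }
}
```

The real network does the same thing at a larger scale (784 inputs to 500 units, then 500 to 100), with weights learned during training instead of chosen by hand.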
Here our network consists of a sequence of two DenseLayers, which are densely connected (also called “fully connected”) neural layers with 500 and 100 units, followed by an OutputLayer. This third and last layer is a 10-way “softmax” layer, which means it will return an array of 10 probability scores (summing to 1). Each score will be the probability that the current digit image belongs to one of our 10 digit classes.
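The softmax function turns the output layer’s 10 raw scores z_1 … z_10 into probabilities p_i = exp(z_i) / (exp(z_1) + … + exp(z_10)), which are all positive and sum to 1. A plain-Java sketch of that formula follows; the class name is hypothetical, and subtracting the largest score before exponentiating is a standard trick to avoid overflow that leaves the result unchanged.

```java
import java.util.Arrays;

// Hypothetical helper: numerically stable softmax over a vector of raw scores.
final class SoftmaxSketch {
    static double[] softmax(double[] scores) {
        double max = Arrays.stream(scores).max().getAsDouble();
        double[] exps = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            exps[i] = Math.exp(scores[i] - max); // shift by max to avoid overflow
            sum += exps[i];
        }
        for (int i = 0; i < exps.length; i++) {
            exps[i] /= sum; // normalise so the probabilities add up to 1
        }
        return exps;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 2.0, 3.0, 0.5, -1.0, 0.0, 0.2, 2.5, 1.5, -0.5};
        double[] probs = softmax(scores);
        System.out.println(Arrays.toString(probs));
        System.out.println(Arrays.stream(probs).sum()); // close to 1.0
    }
}
```

The largest raw score (index 2 here) always gets the largest probability, so softmax changes the scale of the scores but not which digit wins.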
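To make our network ready for training, we need to pick three more things. In Keras this happens in a separate “compilation” step; in DL4J the loss function and the optimizer are already part of the configuration above:
- A loss function: this is how the network will be able to measure how good a job it is doing on its training data, and thus how it will be able to steer itself in the right direction. Here we use negative log-likelihood (LossFunction.NEGATIVELOGLIKELIHOOD) on the output layer.
- An optimizer: this is the mechanism through which the network will update itself based on the data it sees and its loss function. Here we use the Nesterovs updater, i.e. stochastic gradient descent with Nesterov momentum.
- Metrics to monitor during training and testing. Here we will only care about accuracy (the fraction of the images that were correctly classified), which we compute after training with DL4J’s Evaluation class.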
The exact purpose of the loss function and the optimizer will be made clear throughout the next two chapters.
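The loss in this network is negative log-likelihood (LossFunction.NEGATIVELOGLIKELIHOOD in the network configuration above). With one-hot labels it reduces to minus the log of the probability the network assigned to the correct digit, averaged over the examples in a batch: confident correct predictions cost little, confident wrong ones cost a lot. The sketch below computes that value in plain Java as an illustration of the formula; it is not DL4J’s implementation (which, for example, also adds the L2 penalty to the reported score), and the class name is hypothetical.

```java
// Hypothetical helper: mean negative log-likelihood of softmax outputs against one-hot labels.
final class NllSketch {
    // probabilities[n][k]: predicted probability of class k for example n
    // labels[n][k]: 1.0 for the true class of example n, 0.0 elsewhere
    static double meanNll(double[][] probabilities, double[][] labels) {
        double total = 0.0;
        for (int n = 0; n < probabilities.length; n++) {
            for (int k = 0; k < probabilities[n].length; k++) {
                if (labels[n][k] > 0.0) {
                    // -log p(true class); a small floor avoids log(0)
                    total -= labels[n][k] * Math.log(Math.max(probabilities[n][k], 1e-12));
                }
            }
        }
        return total / probabilities.length;
    }

    public static void main(String[] args) {
        double[][] labels = {{0, 1, 0}, {1, 0, 0}};
        double[][] confidentRight = {{0.05, 0.90, 0.05}, {0.80, 0.10, 0.10}};
        double[][] confidentWrong = {{0.90, 0.05, 0.05}, {0.10, 0.80, 0.10}};
        System.out.println(meanNll(confidentRight, labels)); // small loss
        System.out.println(meanNll(confidentWrong, labels)); // much larger loss
    }
}
```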
We are now ready to train our network, which in DL4J is done via a call to the network’s fit method: we “fit” the model to its training data.
int numEpochs = 1; // number of passes over the full training set
for (int i = 0; i < numEpochs; i++) {
    Demo.model.fit(Demo.trainData); // one epoch: one pass over all training mini-batches
}
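Each call to fit above works through the training iterator one mini-batch at a time: for every batch it computes the gradient of the loss with respect to the weights and lets the optimizer update them. The configuration uses the Nesterovs updater with learning rate 0.0015 and momentum 0.98. To show what a momentum-based update does, the plain-Java sketch below applies one common formulation of Nesterov momentum (the gradient is evaluated at a look-ahead point) to a single parameter w minimising the toy loss f(w) = (w - 3)^2. The toy loss, the hyperparameters and the class name are assumptions for illustration; DL4J’s Nesterovs updater may write the update in an algebraically rearranged form.

```java
// Hypothetical sketch: gradient descent with Nesterov momentum on f(w) = (w - 3)^2.
final class NesterovSketch {
    // Gradient of the toy loss f(w) = (w - 3)^2.
    static double gradient(double w) {
        return 2.0 * (w - 3.0);
    }

    // Runs `steps` updates starting from w = 0 and returns the final parameter value.
    static double minimise(double learningRate, double momentum, int steps) {
        double w = 0.0;        // parameter being trained
        double velocity = 0.0; // accumulated momentum of past updates
        for (int t = 0; t < steps; t++) {
            // Nesterov: evaluate the gradient at the look-ahead point w + momentum * velocity
            double g = gradient(w + momentum * velocity);
            velocity = momentum * velocity - learningRate * g;
            w += velocity;
        }
        return w;
    }

    public static void main(String[] args) {
        // Toy hyperparameters chosen so the example converges quickly;
        // the MNIST configuration above uses rate 0.0015 and momentum 0.98.
        System.out.println(minimise(0.1, 0.9, 200)); // approaches the minimum at w = 3
    }
}
```

In the real network the same kind of update is applied to every weight at once, with the gradient coming from backpropagation through the layers instead of a closed-form derivative.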
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.api.ndarray.INDArray;
int outputNum = 10; // number of output classes
Evaluation eval = new Evaluation(outputNum); // create an evaluation object with 10 possible classes
while (Demo.testData.hasNext()) {
    DataSet next = Demo.testData.next();
    INDArray output = Demo.model.output(next.getFeatures()); // get the network's prediction
    eval.eval(next.getLabels(), output); // check the prediction against the true class
}
System.out.println(eval.stats());
========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0.0970
Precision: 0.0718 (1 class excluded from average)
Recall: 0.0971
F1 Score: 0.0757 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]
=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
--------------------------------------------------
  273    0  107   62   15    0   84  236   32  171 | 0 = 0
   49    0   30  307  286    0    0   18    0  445 | 1 = 1
   74    0   56  330  216    1    8  182   51  114 | 2 = 2
  160    0  331  125   82    0   15    5    3  289 | 3 = 3
  269    0   13   36   71    0  114   84    9  386 | 4 = 4
  223    0  147   81   50    0  135   44   36  176 | 5 = 5
  149    0   33   99   83    0    6  127    7  454 | 6 = 6
  242    0  154   79  122    0    3   28   15  385 | 7 = 7
  124    0   28  402   65    0   40   52   24  239 | 8 = 8
  381    0   46   57   61    0   41   19   17  387 | 9 = 9
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
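The confusion matrix is enough to recompute the headline numbers of the report. Accuracy is the sum of the diagonal (correct predictions) divided by the total count, and the recall of a class is its diagonal entry divided by its row total; macro-averaged recall is the plain mean over the 10 classes. Precision divides by column totals instead, which is why class 1, whose column contains only zeros, had to be left out of the average precision. The plain-Java sketch below (hypothetical class name) applies these definitions to the counts printed above.

```java
import java.util.Locale;

// Hypothetical helper: recompute accuracy and macro-averaged recall from a confusion matrix.
// Rows are actual classes and columns are predicted classes, as in DL4J's printout above.
final class ConfusionMetrics {
    // Counts copied from the confusion matrix printed above.
    static final int[][] REPORTED = {
        {273, 0, 107,  62,  15, 0,  84, 236, 32, 171},
        { 49, 0,  30, 307, 286, 0,   0,  18,  0, 445},
        { 74, 0,  56, 330, 216, 1,   8, 182, 51, 114},
        {160, 0, 331, 125,  82, 0,  15,   5,  3, 289},
        {269, 0,  13,  36,  71, 0, 114,  84,  9, 386},
        {223, 0, 147,  81,  50, 0, 135,  44, 36, 176},
        {149, 0,  33,  99,  83, 0,   6, 127,  7, 454},
        {242, 0, 154,  79, 122, 0,   3,  28, 15, 385},
        {124, 0,  28, 402,  65, 0,  40,  52, 24, 239},
        {381, 0,  46,  57,  61, 0,  41,  19, 17, 387}
    };

    // Accuracy = correctly classified (diagonal) / all examples.
    static double accuracy(int[][] m) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[i].length; j++) {
                total += m[i][j];
                if (i == j) {
                    correct += m[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    // Macro recall: mean over classes of (diagonal entry / row total).
    // A class with no examples contributes 0 here (a simplification).
    static double macroRecall(int[][] m) {
        double sum = 0.0;
        for (int i = 0; i < m.length; i++) {
            long rowTotal = 0;
            for (int count : m[i]) {
                rowTotal += count;
            }
            sum += rowTotal == 0 ? 0.0 : (double) m[i][i] / rowTotal;
        }
        return sum / m.length;
    }

    public static void main(String[] args) {
        System.out.printf(Locale.ROOT, "Accuracy: %.4f%n", accuracy(REPORTED));
        System.out.printf(Locale.ROOT, "Recall:   %.4f%n", macroRecall(REPORTED));
    }
}
```

Running it prints Accuracy: 0.0970 and Recall: 0.0971, matching the values in DL4J’s report: 970 of the 10,000 test images lie on the diagonal.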