Java nets

A first look at neural networks with Java

Disclaimer: This is a port of Francois Chollet’s notebooks to Java. All code is written by us, most of the text is taken from Francois’ original version.

In this chapter we’ll give you a first glimpse at what a neural network is and how to run one using Java. We use beakerx notebook extensions for jupyter to run Java code interactively in notebook cells. Since you’re going to work with DeepLearning4j (DL4J) in this notebook, we first need to add a JAR with all the dependencies.

%classpath add jar ./pydl4j-1.0.0-SNAPSHOT-bin.jar

To see that installing DL4J dependencies actually worked, let’s see if we can validate that we’re using DL4J version 1.0.0-SNAPSHOT.

import java.util.List;
import org.nd4j.versioncheck.*;

List<VersionInfo> versions = VersionCheck.getVersionInfos();
System.out.println(versions.get(0).getBuildVersion());

1.0.0-SNAPSHOT

The last step before we can really get started is to define a static Java class that keeps track of all the state we use in this section. The reason we do this is to keep state across notebook cells. In particular, for a neural network we need

Don’t worry if that doesn’t make a lot of sense to you yet, we’ll explain all of this in more detail soon. First, let’s define the static Demo class that will store model and data for this session.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import java.lang.ClassLoader;

public class Demo {
    public static MultiLayerNetwork model;
    public static DataSetIterator trainData;
    public static DataSetIterator testData;
    public static ClassLoader loader = Nd4j.class.getClassLoader();
}

The MNIST dataset comes pre-loaded in DL4J, in the form of two iterators. Each iterator consists of so called DataSets. A DataSet is a handy abstraction for pairs of features and labels.

Your model will learn from trainData and then be tested on testData. Your MNIST images are encoded as ND4J arrays, and the labels are simply ND4J arrays of digits, ranging from 0 to 9.

Let’s have a look at the training and test data:

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

int batchSize = 64; // batch size for each epoch
int rngSeed = 123; // random number seed for reproducibility
Demo.trainData = new MnistDataSetIterator(batchSize, true, rngSeed);
Demo.testData = new MnistDataSetIterator(batchSize, false, rngSeed);
import org.nd4j.linalg.dataset.api.DataSet;
import java.util.Arrays;
Thread.currentThread().setContextClassLoader(Demo.loader);

DataSet batch = Demo.trainData.next(); 
System.out.println(Arrays.toString(batch.getFeatures().shape()));

[64, 784]

Our workflow will be as follow: first we will present our neural network with the training data trainData. The network will then learn to associate images and labels. Finally, we will ask the network to produce predictions for the images in testData, and we will verify if these predictions match the labels from testData.

Let’s build our network – again, remember that you aren’t supposed to understand everything about this example just yet.

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
Thread.currentThread().setContextClassLoader(Demo.loader);

final int numRows = 28;
final int numColumns = 28;
int outputNum = 10; // number of output classes
int batchSize = 64; // batch size for each epoch
int rngSeed = 123; // random number seed for reproducibility
double rate = 0.0015; // learning rate

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(rngSeed)
    .activation(Activation.RELU)
    .weightInit(WeightInit.XAVIER)
    .updater(new Nesterovs(rate, 0.98))
    .l2(rate * 0.005) // regularize learning model
    .list()
    .layer(new DenseLayer.Builder().nIn(numRows * numColumns).nOut(500).build())
    .layer(new DenseLayer.Builder().nIn(500).nOut(100).build())
    .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX).nIn(100).nOut(outputNum).build())
    .pretrain(false).backprop(true) //use backpropagation to adjust weights
    .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));  //print the score with every iteration
Demo.model = model;

The core building block of neural networks is the “layer”, a data-processing module which you can conceive as a “filter” for data. Some data comes in, and comes out in a more useful form. Precisely, layers extract representations out of the data fed into them – hopefully representations that are more meaningful for the problem at hand. Most of deep learning really consists of chaining together simple layers which will implement a form of progressive “data distillation”. A deep learning model is like a sieve for data processing, made of a succession of increasingly refined data filters – the “layers”.

Here our network consists of a sequence of two Dense layers, which are densely-connected (also called “fully-connected”) neural layers. The second (and last) layer is a 10-way “softmax” layer, which means it will return an array of 10 probability scores (summing to 1). Each score will be the probability that the current digit image belongs to one of our 10 digit classes.

To make our network ready for training, we need to pick three more things, as part of “compilation” step:

The exact purpose of the loss function and the optimizer will be made clear throughout the next two chapters.

We are now ready to train our network, which in Keras is done via a call to the fit method of the network: we “fit” the model to its training data.

int numEpochs = 1; // number of epochs to perform

for( int i=0; i<numEpochs; i++ ){
    Demo.model.fit(Demo.trainData);
}
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.api.ndarray.INDArray;


int outputNum = 10; // number of output classes
Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes

while(Demo.testData.hasNext()){
    DataSet next = Demo.testData.next();
    INDArray output = Demo.model.output(next.getFeatures()); //get the networks prediction
    eval.eval(next.getLabels(), output); //check the prediction against the true class
}

System.out.println(eval.stats());
========================Evaluation Metrics========================
    # of classes:    10
    Accuracy:        0.0970
    Precision:       0.0718	(1 class excluded from average)
    Recall:          0.0971
    F1 Score:        0.0757	(1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
    0   1   2   3   4   5   6   7   8   9
-----------------------------------------
    273   0 107  62  15   0  84 236  32 171 | 0 = 0
    49   0  30 307 286   0   0  18   0 445 | 1 = 1
    74   0  56 330 216   1   8 182  51 114 | 2 = 2
    160   0 331 125  82   0  15   5   3 289 | 3 = 3
    269   0  13  36  71   0 114  84   9 386 | 4 = 4
    223   0 147  81  50   0 135  44  36 176 | 5 = 5
    149   0  33  99  83   0   6 127   7 454 | 6 = 6
    242   0 154  79 122   0   3  28  15 385 | 7 = 7
    124   0  28 402  65   0  40  52  24 239 | 8 = 8
    381   0  46  57  61   0  41  19  17 387 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================