diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index b74df2d2c8a3..1f7ce29c038a 100644 --- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -23,7 +23,9 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; +import org.junit.rules.DisableOnDebug; import org.junit.rules.TestName; +import org.junit.rules.TestRule; import org.junit.rules.Timeout; import org.nd4j.common.base.Preconditions; import org.nd4j.common.config.ND4JSystemProperties; @@ -48,7 +50,7 @@ public abstract class BaseDL4JTest { @Rule public TestName name = new TestName(); @Rule - public Timeout timeout = Timeout.millis(getTimeoutMilliseconds()); + public TestRule timeout = new DisableOnDebug(Timeout.millis(getTimeoutMilliseconds())); protected long startTime; protected int threadCountBefore; diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 90c88d4c3d86..5608c066d29e 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -22,6 +22,7 @@ deeplearning4j-parent 1.0.0-SNAPSHOT + @@ -166,6 +167,46 @@ + + + + org.apache.maven.wagon + wagon-http + 2.9 + + + org.kuali.maven.wagons + maven-s3-wagon + 1.2.1 + + + + + + + maven-surefire-plugin + true + + true + false + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g + + + *.java + **/*.java + + + + listener + org.deeplearning4j.samediff.ToSameDiffTests + + + + + + + + test-nd4j-native diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index c004c5c0de1b..1d0147dedc3c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -16,14 +16,42 @@ package org.deeplearning4j; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.BufferedOutputStream; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStream; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import lombok.NonNull; import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.layers.BaseOutputLayer; +import org.deeplearning4j.nn.layers.LossLayer; import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer; import org.deeplearning4j.nn.layers.normalization.BatchNormalization; @@ -31,7 +59,10 @@ import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; @@ -40,14 +71,7 @@ import org.nd4j.linalg.learning.regularization.L2Regularization; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.WeightDecay; - -import java.io.*; -import java.lang.reflect.Field; -import java.util.List; -import java.util.Random; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import org.nd4j.linalg.lossfunctions.ILossFunction; public class TestUtils { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index 709d889017cc..19a0b8f184e2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -28,6 +28,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.TestToSameDiff; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -106,6 +108,7 @@ public void testSelfAttentionLayer() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); } } } @@ -167,6 +170,7 @@ public void testLearnedSelfAttentionLayer() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); } } } @@ -322,6 +326,8 @@ public void testRecurrentAttentionLayer() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(100)); assertTrue(name, gradOK); + + ToSameDiffTests.testToSameDiff(net, in, labels); } } } @@ -385,6 +391,7 @@ public void testAttentionVertex() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in}) .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null).subset(true).maxPerParam(100)); assertTrue(name, gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); } } } @@ -447,6 +454,7 @@ public void testAttentionVertexSameInput() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in}) .labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null)); assertTrue(name, gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 081abd45da12..e3e37093a0da 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -30,11 +30,11 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; @@ -42,8 +42,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import org.nd4j.linalg.profiler.OpProfiler; -import org.nd4j.linalg.profiler.ProfilerConfig; import java.util.Arrays; import java.util.HashSet; @@ -104,6 +102,7 @@ public void testGradient2dSimple() { .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -149,6 +148,7 @@ public void testGradientCnnSimple() { .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -251,6 +251,7 @@ public void testGradientBNWithCNNandSubsampling() { .labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25)); //Most params are in output layer, only these should be skipped with this threshold assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -355,6 +356,7 @@ public void testGradientDense() { .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -399,6 +401,7 @@ public void testGradient2dFixedGammaBeta() { .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -444,6 +447,7 @@ public void testGradientCnnFixedGammaBeta() { .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -489,6 +493,7 @@ public void testBatchNormCompGraphSimple() { assertTrue(gradOK); TestUtils.testModelSerialization(net); + ToSameDiffTests.testToSameDiff(net, input, labels); } } @@ -587,6 +592,7 @@ public void testGradientBNWithCNNandSubsamplingCompGraph() { assertTrue(gradOK); TestUtils.testModelSerialization(net); + ToSameDiffTests.testToSameDiff(net, input, labels); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index 06fe0cf350a6..834d208d83e9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -27,8 +27,8 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.deeplearning4j.util.Convolution1DUtils; -import org.deeplearning4j.util.ConvolutionUtils; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -122,6 +122,7 @@ public void testCnn1DWithLocallyConnected1D() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } @@ -202,6 +203,7 @@ public void testCnn1DWithCropping1D() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -285,6 +287,7 @@ public void testCnn1DWithZeroPadding1D() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -362,6 +365,7 @@ public void testCnn1DWithSubsampling1D() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -423,6 +427,7 @@ public void testCnn1dWithMasking(){ .labels(label).inputMask(fm)); assertTrue(s, gradOK); + ToSameDiffTests.testToSameDiff(net, f, label); TestUtils.testModelSerialization(net); //TODO also check that masked step values don't impact forward pass, score or gradients @@ -518,6 +523,7 @@ public void testCnn1Causal() { .labels(label).inputMask(fm)); assertTrue(s, gradOK); + ToSameDiffTests.testToSameDiff(net, f, label); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index 30cc783da458..f51eff628105 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -29,9 +29,9 @@ import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -159,6 +159,7 @@ public void testCnn3DPlain() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -262,6 +263,7 @@ public void testCnn3DZeroPadding() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } @@ -352,6 +354,7 @@ public void testCnn3DPooling() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -442,6 +445,7 @@ public void testCnn3DUpsampling() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -541,6 +545,7 @@ public void testCnn3DCropping() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } @@ -632,6 +637,7 @@ public void testDeconv3d() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index c303cc594498..1b617a2f7ab2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -16,7 +16,6 @@ package org.deeplearning4j.gradientcheck; -import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; @@ -33,6 +32,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; @@ -153,6 +153,7 @@ public void testGradientCNNMLN() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -247,6 +248,8 @@ public void testGradientCNNL1L2MLN() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); + + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -309,6 +312,7 @@ public void testCnnWithSpaceToDepth() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -378,6 +382,7 @@ public void testCnnWithSpaceToBatch() { .labels(new INDArray[]{labels})); assertTrue(msg + " - compgraph", gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -438,6 +443,7 @@ public void testCnnWithUpsampling() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -509,6 +515,7 @@ public void testCnnWithSubsampling() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -578,6 +585,7 @@ public void testCnnWithSubsamplingV2() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -638,6 +646,8 @@ public void testCnnLocallyConnected2D() { assertTrue(msg, gradOK); + //TODO existing define method requires offline shape inference + // ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -705,6 +715,7 @@ public void testCnnMultiLayer() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -770,6 +781,7 @@ public void testCnnSamePaddingMode() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -837,6 +849,7 @@ public void testCnnSamePaddingModeStrided() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -920,6 +933,7 @@ public void testCnnZeroPaddingLayer() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -995,6 +1009,7 @@ public void testDeconvolution2D() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -1068,6 +1083,7 @@ public void testSeparableConv2D() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -1152,6 +1168,7 @@ public void testCnnDilated() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -1227,6 +1244,7 @@ public void testCropping2DLayer() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -1298,6 +1316,7 @@ public void testDepthwiseConv2D() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java index e604c594dc1a..0ff5ce3e38ac 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.PrimaryCapsules; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitDistribution; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; @@ -39,8 +40,6 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; -import java.util.Random; - public class CapsnetGradientCheckTest extends BaseDL4JTest { @Override @@ -114,6 +113,7 @@ public void testCapsNet() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java index c4f9d2843af0..316a79912dfc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -141,6 +142,7 @@ public void testDropoutGradient() { false, -1, null, 12345); //Last arg: ensures RNG is reset at each iter... otherwise will fail due to randomness! assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, f, l); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index 742052a42b54..e2bea569a386 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -106,6 +107,7 @@ public void testRNNGlobalPoolingBasicMultiLayer() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -165,6 +167,7 @@ public void testCnnGlobalPoolingBasicMultiLayer() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -225,6 +228,7 @@ public void testLSTMWithMasking() { .labels(labels).inputMask(featuresMask)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -308,6 +312,7 @@ public void testCnnGlobalPoolingMasking() { .labels(labels).inputMask(inputMask)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 2c6f8843e375..a8022d0753a4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -27,11 +27,13 @@ import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -136,6 +138,7 @@ public void testMinibatchApplication() { String msg = "testMinibatchApplication() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, ds.getFeatures(), ds.getLabels()); TestUtils.testModelSerialization(mln); } @@ -216,6 +219,7 @@ public void testGradientMLP2LayerIrisSimple() { String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -311,6 +315,8 @@ public void testGradientMLP2LayerIrisL1L2Simple() { + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1; assertTrue(msg, gradOK); + + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -394,6 +400,7 @@ public void testEmbeddingLayerSimple() { String msg = "testEmbeddingLayerSimple"; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } @@ -446,6 +453,7 @@ public void testAutoEncoder() { .activation(afn).build()) .layer(1, new OutputLayer.Builder(lf).nIn(3).nOut(3) .activation(outputActivation).build()) + .setInputType(InputType.inferInputType(input)) .build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -482,6 +490,7 @@ public void testAutoEncoder() { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -553,6 +562,7 @@ public void elementWiseMultiplicationLayerTest(){ assertTrue(msg, gradOK); TestUtils.testModelSerialization(netGraph); + ToSameDiffTests.testToSameDiff(netGraph, features, labels); } } @@ -579,6 +589,7 @@ public void testEmbeddingSequenceLayer(){ .layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH) .dataFormat(seqOutputFormat) .lossFunction(LossFunction.MSE).build()) + .setInputType(InputType.recurrent(3, 6, RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -606,6 +617,7 @@ public void testEmbeddingSequenceLayer(){ boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(label).inputMask(fMask)); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, in, label); TestUtils.testModelSerialization(net); @@ -705,6 +717,7 @@ public void testGradientWeightDecay() { + ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1; assertTrue(msg, gradOK1); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -789,6 +802,7 @@ public void testGradientMLP2LayerIrisLayerNorm() { String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", layerNorm=" + layerNorm; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index ac3c3deea8ef..1c5a46431988 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -32,12 +32,10 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; -import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -48,7 +46,6 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.Arrays; import java.util.Map; import java.util.Random; @@ -116,6 +113,7 @@ public void testBasicIris() { String msg = "testBasicIris()"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } @Test @@ -167,6 +165,7 @@ public void testBasicIrisWithMerging() { String msg = "testBasicIrisWithMerging()"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } @Test @@ -224,6 +223,7 @@ public void testBasicIrisWithElementWiseNode() { String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } } @@ -284,6 +284,7 @@ public void testBasicIrisWithElementWiseNodeInputSizeGreaterThanTwo() { String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } } @@ -331,6 +332,7 @@ public void testElementWiseVertexBroadcast(){ .labels(new INDArray[]{labels})); assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, in, labels); } } } @@ -383,6 +385,7 @@ public void testCnnDepthMerge() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{input}, new INDArray[]{labels}, new InputType[]{InputType.convolutional(6, 6, 2, format)}); } } @@ -443,6 +446,7 @@ public void testRNNWithMerging() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } } @@ -480,6 +484,7 @@ public void testLSTMWithSubset() { String msg = "testLSTMWithSubset()"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } @Test @@ -528,6 +533,7 @@ public void testLSTMWithLastTimeStepVertex() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } @Test @@ -577,6 +583,7 @@ public void testLSTMWithDuplicateToTimeSeries() { String msg = "testLSTMWithDuplicateToTimeSeries()"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{input1, input2}, new INDArray[]{labels}); } @Test @@ -636,6 +643,7 @@ public void testLSTMWithReverseTimeSeriesVertex() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, labels); } @Test @@ -679,6 +687,7 @@ public void testMultipleInputsLayer() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, inputs, new INDArray[]{out}); } } @@ -719,6 +728,7 @@ public void testMultipleOutputsLayer() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, out); } } @@ -765,6 +775,7 @@ public void testMultipleOutputsMergeVertex() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, new INDArray[]{out}); } } @@ -816,6 +827,7 @@ public void testMultipleOutputsMergeCnn() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, input, out); } } @@ -885,6 +897,7 @@ public void testBasicIrisTripletStackingL2Loss() { String msg = "testBasicIrisTripletStackingL2Loss()"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{pos, anc, neg}, new INDArray[]{labels}); } @@ -945,6 +958,7 @@ public void testBasicCenterLoss() { assertTrue(msg, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, example, labels); } } } @@ -1009,6 +1023,7 @@ public void testCnnPoolCenterLoss() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, example, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, example, labels); TestUtils.testModelSerialization(net); } } @@ -1059,6 +1074,7 @@ public void testBasicL2() { assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels}); } } @@ -1117,6 +1133,7 @@ public void testBasicStackUnstack() { assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2}); } } @@ -1175,6 +1192,7 @@ public void testBasicStackUnstackDebug() { assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2}); } } @@ -1240,6 +1258,7 @@ public void testBasicStackUnstackVariableLengthTS() { assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2}); } } @@ -1296,6 +1315,7 @@ public void testBasicTwoOutputs() { .labels(new INDArray[]{labels1, labels2})); assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2}); } } @@ -1339,6 +1359,7 @@ public void testL2NormalizeVertex2d() { assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, in1, labels1); } } @@ -1388,6 +1409,7 @@ public void testL2NormalizeVertex4d() { assertTrue(testName, gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, in1, labels1); } } @@ -1427,5 +1449,6 @@ public void testGraphEmbeddingLayerSimple() { String msg = "testGraphEmbeddingLayerSimple"; assertTrue(msg, gradOK); TestUtils.testModelSerialization(cg); + ToSameDiffTests.testToSameDiff(cg, input, labels); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index 09afd6c2f4c7..613b5ab5944b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -137,6 +138,7 @@ public void gradientCheckMaskingOutputSimple() { String msg = "gradientCheckMaskingOutputSimple() - timeSeriesLength=" + timeSeriesLength + ", miniBatchSize=" + 1; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -186,6 +188,7 @@ public void testBidirectionalLSTMMasking() { .labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(12)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -267,6 +270,7 @@ public void testPerOutputMaskingMLP() { .labels(labels).labelMask(labelMask)); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, features, labels); TestUtils.testModelSerialization(net); } } @@ -384,6 +388,7 @@ public void testPerOutputMaskingRnn() { assertTrue(msg + " (compgraph)", gradOK); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, features, labels); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index 5a6e003f258a..9ecc62f9e88b 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -96,6 +97,7 @@ public void testGradientLRNSimple() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index 2f0822b80b16..80b4bc241255 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -137,6 +138,7 @@ public void testLSTMBasicMultiLayer() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(testName, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -226,6 +228,7 @@ public void testGradientLSTMFull() { .labels(labels).subset(true).maxPerParam(128)); assertTrue(testName, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -276,6 +279,7 @@ public void testGradientLSTMEdgeCases() { boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -356,6 +360,7 @@ public void testGradientGravesBidirectionalLSTMFull() { String msg = "testGradientGravesLSTMFull() - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -405,6 +410,7 @@ public void testGradientGravesBidirectionalLSTMEdgeCases() { String msg = "testGradientGravesLSTMEdgeCases() - timeSeriesLength=" + timeSeriesLength[i] + ", miniBatchSize=" + miniBatchSize[i]; assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -460,6 +466,7 @@ public void testGradientCnnFfRnn() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) .labels(labels).subset(true).maxPerParam(32)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index 88ea98ca23ea..e67969c57fa8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; @@ -226,6 +227,7 @@ public void lossFunctionGradientCheck() { } else { failed.add(testName); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -395,6 +397,7 @@ public void lossFunctionGradientCheckLossLayer() { } TestUtils.testModelSerialization(net); + ToSameDiffTests.testToSameDiff(net, input, labels); } } @@ -703,6 +706,8 @@ public void lossFunctionWeightedGradientCheck() { } else { failed.add(testName); } + + ToSameDiffTests.testToSameDiff(net, input, labels); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index cc4e7410413d..dcdf58d53395 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -121,6 +122,7 @@ public void testGradientNoBiasDenseOutput() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -178,6 +180,7 @@ public void testGradientNoBiasRnnOutput() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -240,6 +243,7 @@ public void testGradientNoBiasEmbedding() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -306,6 +310,7 @@ public void testCnnWithSubsamplingNoBias() { assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 24745fd0a1e1..222f3d497491 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -22,8 +22,10 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -146,6 +148,7 @@ public void testRnnLossLayer() { .labels(labels).labelMask(labelMask)); assertTrue(testName, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -231,11 +234,13 @@ public void testCnnLossLayer() { .convolutionMode(ConvolutionMode.Same) .list() .layer(new ConvolutionLayer.Builder().nIn(dIn).nOut(dOut).activation(Activation.TANH) + .kernelSize(3, 3) .dist(new NormalDistribution(0, 1.0)) .updater(new NoOp()).build()) .layer(new CnnLossLayer.Builder(lf) .activation(oa) .build()) + .setInputType(InputType.inferInputType(input)) .validateOutputLayerConfig(false).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -253,6 +258,7 @@ public void testCnnLossLayer() { .labels(labels).labelMask(labelMask)); assertTrue(testName, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -385,6 +391,7 @@ public void testCnn3dLossLayer() { .lossFunction(lf) .activation(oa) .build()) + .setInputType(InputType.inferInputType(input)) .validateOutputLayerConfig(false).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -402,6 +409,8 @@ public void testCnn3dLossLayer() { .labels(labels).labelMask(labelMask)); assertTrue(testName, gradOK); + //TODO known loss issue due to DL4J packing dimensions into batch + ToSameDiffTests.testToSameDiff(mln, input, null); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index e356cce1d458..0d747e61ce15 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -133,6 +134,7 @@ public void testBidirectionalWrapper() { assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); TestUtils.testModelSerialization(net); } } @@ -211,6 +213,7 @@ public void testSimpleRnn() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); TestUtils.testModelSerialization(net); } } @@ -286,6 +289,7 @@ public void testLastTimeStepLayer(){ boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(16)); assertTrue(name, gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); TestUtils.testModelSerialization(net); } } @@ -350,6 +354,7 @@ public void testTimeDistributedDense() { boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in) .labels(labels).inputMask(inMask).subset(true).maxPerParam(16)); assertTrue(name, gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index b8412a8d26a5..ac3bd773f317 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -186,6 +187,7 @@ public void testMaskLayer() { .input(input).labels(label).inputMask(inMask)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(net, input, label); TestUtils.testModelSerialization(net); } } @@ -226,6 +228,7 @@ public void testFrozenWithBackprop(){ .labels(labels).excludeParams(excludeParams)); assertTrue(gradOK); + ToSameDiffTests.testToSameDiff(net, in, labels); TestUtils.testModelSerialization(net); @@ -238,6 +241,7 @@ public void testFrozenWithBackprop(){ assertTrue(gradOKCG); TestUtils.testModelSerialization(g); + ToSameDiffTests.testToSameDiff(g, in, labels); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index 3d4a6180c5bc..cf71fb831ffb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -21,10 +21,12 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationTanH; @@ -116,6 +118,7 @@ public void testVaeAsMLP() { .dist(new NormalDistribution(0, 1)) .build()) + .setInputType(InputType.inferInputType(input)) .build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -135,6 +138,7 @@ public void testVaeAsMLP() { DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -184,6 +188,7 @@ public void testVaePretrain() { .reconstructionDistribution( new GaussianReconstructionDistribution(pxzAfn)) .activation(afn).build()) + .setInputType(InputType.inferInputType(input)) .build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -207,6 +212,7 @@ public void testVaePretrain() { RETURN_ON_FIRST_FAILURE, input, 12345); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, input, labels); TestUtils.testModelSerialization(mln); } } @@ -275,6 +281,7 @@ public void testVaePretrainReconstructionDistributions() { reconstructionDistributions[i]) .activation(Activation.TANH) .build()) + .setInputType(InputType.inferInputType(data)) .build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -295,6 +302,7 @@ public void testVaePretrainReconstructionDistributions() { data, 12345); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, data, null); TestUtils.testModelSerialization(mln); } } @@ -317,6 +325,7 @@ public void testVaePretrainMultipleSamples() { new GaussianReconstructionDistribution(Activation.TANH)) .numSamples(numSamples).activation(Activation.TANH) .build()) + .setInputType(InputType.inferInputType(features)) .build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -337,6 +346,7 @@ public void testVaePretrainMultipleSamples() { features, 12345); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(mln, features, null); TestUtils.testModelSerialization(mln); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 47c040c1214e..9d4704b3aac9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -153,6 +154,7 @@ public void testYoloOutputLayer() { .labels(labels).subset(true).maxPerParam(100)); assertTrue(msg, gradOK); + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -261,6 +263,7 @@ public void yoloGradientCheckRealData() throws Exception { .labels(l).inputMask(null).subset(true).maxPerParam(64)); assertTrue(ok); + ToSameDiffTests.testToSameDiff(net, f, l); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java index dbef14bf27f6..7c3a901b60b4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java @@ -24,7 +24,7 @@ public class SDLossMAE extends SameDiffLoss { @Override - public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { - return sd.math.abs(labels.sub(layerInput)).mean(1); + public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) { + return sameDiff.math.abs(labels.sub(layerInput)).mean(1); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java index 6edce7a499c8..5eb6b91bec4f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java @@ -24,7 +24,7 @@ public class SDLossMSE extends SameDiffLoss { @Override - public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { + public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) { return labels.squaredDifference(layerInput).mean(1); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java index 0db6a13572f3..e981bb115ac0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java @@ -37,6 +37,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -100,6 +101,7 @@ public void testLayerRecurrentConstraints() throws Exception { assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -153,6 +155,7 @@ public void testLayerBiasConstraints() throws Exception { assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -205,6 +208,7 @@ public void testLayerWeightsConstraints() throws Exception { assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -265,6 +269,7 @@ public void testLayerWeightsAndBiasConstraints() throws Exception { assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -326,6 +331,7 @@ public void testLayerWeightsAndBiasSeparateConstraints() throws Exception { assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } @@ -384,6 +390,7 @@ public void testModelConstraints() throws Exception { assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 ); } + ToSameDiffTests.testToSameDiff(net, input, labels); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java index 574cb0d39c23..9f2a2cd1de63 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -144,7 +145,7 @@ public void testCalls(){ } @Data - public static class CustomDropout implements IDropout{ + public static class CustomDropout extends BaseDropout{ private List> allCalls = new ArrayList<>(); private List> allReverseCalls = new ArrayList<>(); @@ -191,6 +192,7 @@ public void testSerialization(){ MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); + ToSameDiffTests.testToSameDiff(net, null, null); TestUtils.testModelSerialization(net); ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java index 77061e3987ee..208e88295a99 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java @@ -18,8 +18,8 @@ import lombok.EqualsAndHashCode; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -28,7 +28,7 @@ * Created by Alex on 09/09/2016. */ @EqualsAndHashCode -public class MyCustomPreprocessor implements InputPreProcessor { +public class MyCustomPreprocessor extends BaseInputPreProcessor { @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { @@ -41,7 +41,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w } @Override - public InputPreProcessor clone() { + public BaseInputPreProcessor clone() { return new MyCustomPreprocessor(); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java index 0ebc598bc68b..444373ce3d94 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java @@ -32,6 +32,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -77,6 +78,7 @@ public void testWeightNoiseConfigJson() { assertEquals(wn, ((BaseLayer) net.getLayer(2).conf().getLayer()).getWeightNoise()); TestUtils.testModelSerialization(net); + ToSameDiffTests.testToSameDiff(net, null, null); ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() @@ -97,6 +99,7 @@ public void testWeightNoiseConfigJson() { assertEquals(wn, ((BaseLayer) graph.getLayer(2).conf().getLayer()).getWeightNoise()); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, Nd4j.create(1,10), Nd4j.create(1,10)); graph.fit(new DataSet(Nd4j.create(1,10), Nd4j.create(1,10))); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index beec5cf2042c..ecc731a85221 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -20,6 +20,7 @@ import org.deeplearning4j.nn.conf.preprocessor.*; import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer; import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.nd4j.shade.guava.collect.ImmutableSet; import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; @@ -1024,6 +1025,8 @@ public void testEmbeddingDtypes() { logUsedClasses(net); + ToSameDiffTests.testToSameDiff(net, input, label); + //Now, test mismatched dtypes for input/labels: for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { INDArray in2 = input.castTo(inputLabelDtype); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index b0cc17376248..b20517e316c9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -56,6 +56,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.deeplearning4j.util.ModelSerializer; import org.junit.*; import org.junit.rules.TemporaryFolder; @@ -1471,6 +1472,8 @@ public void testZeroParamNet() throws Exception { ComputationGraph net2 = TestUtils.testModelSerialization(net); INDArray out2 = net2.outputSingle(ds.getFeatures()); assertEquals(out, out2); + // labels are wrong size, would need to be [batch, 1] fr LossLayer. Convolutional input is handled via preprocessor here, not CnnLossLayer. + ToSameDiffTests.testToSameDiff(net, ds.getFeatures(), null); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java index 23c4421e57fe..dff0023ab2e4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java @@ -32,6 +32,9 @@ import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.TestToSameDiff; +import org.deeplearning4j.samediff.ToSameDiffTests; +import org.deeplearning4j.util.ToSameDiffUtils; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -144,6 +147,8 @@ public void testDropoutLayerWithoutTraining() throws Exception { assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(4)); assertEquals(actTestIntegrated.get(1), actTestSeparate.get(2)); assertEquals(actTestIntegrated.get(2), actTestSeparate.get(4)); + + ToSameDiffTests.testToSameDiff(netIntegrated, in, null); } @Test @@ -297,5 +302,7 @@ public void testDropoutLayerWithConvMnist() throws Exception { List actTestSeparate = netSeparate.feedForward(false); assertEquals(actTestIntegrated.get(1), actTestSeparate.get(1)); assertEquals(actTestIntegrated.get(2), actTestSeparate.get(3)); + ToSameDiffTests.testToSameDiff(netIntegrated, next.getFeatures().dup(), next.getLabels().dup()); + ToSameDiffTests.testToSameDiff(netSeparate, next.getFeatures().dup(), next.getLabels().dup()); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java index 65d204964fc3..669bb7484cf6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -372,5 +373,7 @@ public void testFrozenLayerInstantiationCompGraph() { INDArray out3 = net3.outputSingle(input); assertEquals(out2, out3); + + ToSameDiffTests.testToSameDiff(net2, input, null); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java index 200f55071d24..54a9e59220c0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -98,6 +99,8 @@ public void testFrozenWithBackpropLayerInstantiation() { INDArray out3 = net3.output(input); assertEquals(out2, out3); + + ToSameDiffTests.testToSameDiff(net1, input, null); } @Test @@ -153,6 +156,7 @@ public void testFrozenLayerInstantiationCompGraph() { INDArray out3 = net3.outputSingle(input); assertEquals(out2, out3); + ToSameDiffTests.testToSameDiff(net2, input, null); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index 2d746a1372ec..acf736a50651 100755 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; @@ -32,6 +33,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -332,6 +334,7 @@ public void testCompareRnnOutputRnnLoss(){ assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); assertEquals(mln.score(), mln2.score(), 1e-6); + ToSameDiffTests.testToSameDiff(mln, in, labels); TestUtils.testModelSerialization(mln); } @@ -361,6 +364,7 @@ public void testCnnLossLayer(){ .layer(new CnnLossLayer.Builder(LossFunction.MSE) .activation(a) .build()) + .setInputType(InputType.convolutional(5, 5, 4)) .build(); MultiLayerConfiguration conf2 = @@ -421,6 +425,7 @@ public void testCnnLossLayer(){ assertArrayEquals(new long[]{2, 1}, s.shape()); assertEquals(s.getDouble(0), s.getDouble(1), 1e-6); + ToSameDiffTests.testToSameDiff(mln, in2, labels2); TestUtils.testModelSerialization(mln); } } @@ -515,6 +520,8 @@ public void testCnnLossLayerCompGraph(){ assertEquals(s.getDouble(0), s.getDouble(1), 1e-6); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, in, labels); + ToSameDiffTests.testToSameDiff(graph2, in2, labels2); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java index 76d14d47d46f..f5dea74d58a1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java @@ -27,16 +27,20 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.deeplearning4j.util.ConvolutionUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -924,6 +928,14 @@ public static void testHelper(TestCase tc) { assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2)); } + //TODO LocallyConnected NPEs because of the lack of SDVariable shapes + if(!(tc.net1.getnLayers() > 1 && tc.net1.getLayer(1).getConfig() instanceof LocallyConnected2D)) { + ToSameDiffTests.testToSameDiff(tc.net1, inNCHW, null); + ToSameDiffTests.testToSameDiff(tc.net2, inNCHW, null); + ToSameDiffTests.testToSameDiff(tc.net3, inNHWC, null); + ToSameDiffTests.testToSameDiff(tc.net4, inNHWC, null); + } + } private static List differentGrads(Gradient g1, Gradient g2){ @@ -943,7 +955,7 @@ private static List differentGrads(Gradient g1, Gradient g2){ //Converts NHWC to NCHW activations @EqualsAndHashCode - private static class NHWCToNCHWPreprocessor implements InputPreProcessor { + private static class NHWCToNCHWPreprocessor extends BaseInputPreProcessor { @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { @@ -956,7 +968,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w } @Override - public InputPreProcessor clone() { + public BaseInputPreProcessor clone() { return this; } @@ -970,6 +982,11 @@ public InputType getOutputType(InputType inputType) { public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { return null; } + + @Override + public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return input.permute(0, 3, 1, 2); + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java index 1d354ef519de..3595ec327abf 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java @@ -72,6 +72,7 @@ public void testConvolutionLayerSetup() { builder.setInputType(InputType.convolutionalFlat(28, 28, 1)); MultiLayerConfiguration completed = complete().build(); MultiLayerConfiguration test = builder.build(); + test.setInputType(null); assertEquals(completed, test); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 431831487b94..43586ef4d7d3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -34,6 +34,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -693,6 +694,7 @@ public void test1dInputType(){ INDArray in = Nd4j.create(2, 10, 6); INDArray out = net.output(in); assertArrayEquals(new long[]{2,7,6}, out.shape()); + ToSameDiffTests.testToSameDiff(net, in, null); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java index 01d94e6dcb08..d13348368e90 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java @@ -26,7 +26,7 @@ * Created by Alex on 19/12/2016. */ @EqualsAndHashCode -public class CustomActivation extends BaseActivationFunction implements IActivation { +public class CustomActivation extends BaseActivationFunction { @Override public INDArray getActivation(INDArray in, boolean training) { return in; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java index ff79adf0a3de..d74708bd5627 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -28,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.lossfunctions.LossFunctions; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 96ab25267799..18ddf29b47e5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -33,6 +33,7 @@ import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; @@ -556,6 +557,7 @@ public void testW2VInits(){ INDArray w = net.getParam("0_W"); assertEquals(vectors, w); + ToSameDiffTests.testToSameDiff(net, null, null); TestUtils.testModelSerialization(net); //Test same thing for embedding sequence layer: @@ -573,6 +575,7 @@ public void testW2VInits(){ .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) .nOut(4).build()) + .setInputType(InputType.feedForward(10)) .build(); net = new MultiLayerNetwork(conf); @@ -581,6 +584,7 @@ public void testW2VInits(){ w = net.getParam("0_W"); assertEquals(vectors, w); + ToSameDiffTests.testToSameDiff(net, null, null); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index e2c38bfad457..d5a76b051728 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -39,6 +39,7 @@ import org.deeplearning4j.nn.updater.MultiLayerUpdater; import org.deeplearning4j.nn.updater.UpdaterBlock; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Before; import org.junit.Test; import org.nd4j.linalg.activations.Activation; @@ -451,6 +452,7 @@ public void checkSerialization() throws Exception { assertEquals(out, out2); + ToSameDiffTests.testToSameDiff(net, in, null); MultiLayerNetwork net2 = TestUtils.testModelSerialization(net); INDArray outDeser = net2.output(in, false); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java index 699d6bf552c8..26badf332424 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java @@ -21,6 +21,7 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.TemporaryFolder; @@ -93,6 +94,7 @@ public void testYoloActivateScoreBasic() { .layer(new Yolo2OutputLayer.Builder() .boundingBoxPriors(bbPrior) .build()) + .setInputType(InputType.convolutional(h, w, depth)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -159,6 +161,8 @@ public void testYoloActivateScoreBasic() { assertArrayEquals(new long[]{mb,1}, scoreArr1.shape()); assertArrayEquals(new long[]{mb,1}, scoreArr2.shape()); assertNotEquals(scoreArr1, scoreArr2); + + ToSameDiffTests.testToSameDiff(net, input, labels); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index 2fef54844c5b..983ce9e94efb 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -39,6 +39,8 @@ import org.deeplearning4j.nn.updater.MultiLayerUpdater; import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.TestToSameDiff; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.TimeSeriesUtils; import org.junit.Test; @@ -176,6 +178,7 @@ public void compareImplementations(){ INDArray p1 = net1.params(); INDArray p2 = net2.params(); assertEquals(p1, p2); + ToSameDiffTests.testToSameDiff(net1, InputType.inferInputType(in), in, labels); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index 7ddc31220987..70d846e85125 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -22,10 +22,12 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -111,10 +113,12 @@ public void testSerialization(){ .list() .layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder() .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()) + .setInputType(InputType.recurrent(4, 10)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); + ToSameDiffTests.testToSameDiff(net, null, null); TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java index 93566050f39d..e1c2727a3da7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -35,6 +35,7 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -374,6 +375,11 @@ public static void testHelper(TestCase tc) { assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW assertEquals(tc.msg, out1, net4a.output(inNWC)); } + + ToSameDiffTests.testToSameDiff(tc.net1, inNCW, null); + ToSameDiffTests.testToSameDiff(tc.net2, inNCW, null); + ToSameDiffTests.testToSameDiff(tc.net3, inNWC, null); + ToSameDiffTests.testToSameDiff(tc.net4, inNWC, null); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 9f60d674dd00..ba902a7d8203 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -119,6 +120,7 @@ public void testLastTimeStepVertex() { assertEquals(expOut, outFwd); TestUtils.testModelSerialization(graph); + ToSameDiffTests.testToSameDiff(graph, in, null); } @Test diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index b2b789d589b0..ba546b7b3ee5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -244,6 +244,8 @@ public void testMismatchedInputLabelLength(){ if(msg == null) t.printStackTrace(); System.out.println(i); + + //TODO Add checks & error message in RNNOutput layer, etc in loss calculation before reshape. assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index 639d3fafdc1c..8be0b29b1e49 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -117,6 +118,7 @@ public void testSimpleRnn(){ } + ToSameDiffTests.testToSameDiff(net, null, null); TestUtils.testModelSerialization(net); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index 0d38d699cea5..47b33d93ddcf 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -16,6 +16,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -104,6 +105,7 @@ public void testTimeDistributed(){ MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2); out2 = net2.output(in); INDArray out3 = net3.output(in); + ToSameDiffTests.testToSameDiff(net3, in, labels); assertEquals(out2, out3); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java index 243909bd9b70..78003b6af860 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java @@ -122,7 +122,8 @@ public void validateInput(INDArray input) { } @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { return layerInput; } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index 4e923bf4aa63..58724c7782db 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffDenseVertex; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -172,6 +173,7 @@ public void testSameDiffDenseVertex() { INDArray outMbsd = netSD.output(newIn)[0]; INDArray outMb = netStandard.output(newIn)[0]; assertEquals(outMb, outMbsd); + ToSameDiffTests.testToSameDiff(netSD, in, l); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java index 8368d3869f26..d8f0dbdf18f9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaLayer; import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaVertex; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -130,6 +131,7 @@ public void testSameDiffLamdaLayerBasic(){ INDArray outMbsd = lambda.output(newIn)[0]; INDArray outMb = std.output(newIn)[0]; assertEquals(outMb, outMbsd); + ToSameDiffTests.testToSameDiff(lambda, in, labels); } } @@ -216,6 +218,7 @@ public void testSameDiffLamdaVertexBasic(){ INDArray outMbsd = lambda.output(newIn1, newIn2)[0]; INDArray outMb = std.output(newIn1, newIn2)[0]; assertEquals(outMb, outMbsd); + ToSameDiffTests.testToSameDiff(lambda, new INDArray[]{in1, in2}, new INDArray[]{labels}); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java index 9cbbccaa741b..bf16fb314407 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java @@ -48,7 +48,8 @@ protected MinimalSameDiffDense(){ } @Override - public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, SDVariable mask, + Map paramTable) { SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY); SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java index 7b78c14fccad..f2a1952b6b0d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java @@ -126,7 +126,8 @@ public void initializeParameters(Map params) { } @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java index 630b6059c169..f971238f745f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java @@ -102,7 +102,8 @@ public void initializeParameters(Map params){ } @Override - public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, SDVariable mask, + Map paramTable) { SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY); SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java index 9ada8a1dc8d6..45af1cc12192 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java @@ -27,7 +27,7 @@ public class SameDiffMSELossLayer extends SameDiffOutputLayer { @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map paramTable) { + public SDVariable defineLayerAndLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map paramTable) { //MSE: 1/nOut * (input-labels)^2 SDVariable diff = layerInput.sub(labels); return diff.mul(diff).mean(1).sum(0); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java index cfead640f0ff..f2ff1ff50d34 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java @@ -44,7 +44,7 @@ public SameDiffMSEOutputLayer(int nIn, int nOut, Activation activation, WeightIn } @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map paramTable) { + public SDVariable defineLayerAndLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map paramTable) { SDVariable z = sameDiff.mmul(layerInput, paramTable.get("W")).add(paramTable.get("b")); SDVariable out = activation.asSameDiff("out", sameDiff, z); //MSE: 1/nOut * (input-labels)^2 diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index 4282c5e62bbe..deea4be60125 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.misc.iter.WSTestDataSetIterator; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -208,7 +209,7 @@ public void testWithPreprocessorsMLN() { } } - public static class DupPreProcessor implements InputPreProcessor { + public static class DupPreProcessor extends BaseInputPreProcessor { @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr mgr) { return mgr.dup(ArrayType.ACTIVATIONS, input); @@ -220,7 +221,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w } @Override - public InputPreProcessor clone() { + public BaseInputPreProcessor clone() { return new DupPreProcessor(); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 3139b096a687..682eff5ccfc6 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -48,6 +48,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.deeplearning4j.util.ModelSerializer; import org.junit.*; import org.nd4j.linalg.activations.Activation; @@ -1041,6 +1042,7 @@ public void testEpochCounter() throws Exception { assertEquals(4, net.getLayerWiseConfigurations().getEpochCount()); + ToSameDiffTests.testToSameDiff(net, null, null); MultiLayerNetwork restored = TestUtils.testModelSerialization(net); assertEquals(4, restored.getLayerWiseConfigurations().getEpochCount()); } @@ -1242,6 +1244,7 @@ public void testZeroParamNet() throws Exception { net.fit(ds); + ToSameDiffTests.testToSameDiff(net, null, null); MultiLayerNetwork net2 = TestUtils.testModelSerialization(net); INDArray out2 = net2.output(ds.getFeatures()); assertEquals(out, out2); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java index aa552a859dd9..07b333f775a9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.samediff.ToSameDiffTests; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -88,6 +89,8 @@ public void testModelSerializerFrozenLayers() throws Exception { assertEquals(out, out2); + ToSameDiffTests.testToSameDiff(withFrozen, in, null); + //Sanity check on train mode: out = withFrozen.output(in, true); out2 = restored.output(in, true); @@ -141,5 +144,6 @@ public void testModelSerializerFrozenLayersCompGraph() throws Exception { //Sanity check on train mode: out = withFrozen.outputSingle(true, in); out2 = restored.outputSingle(true, in); + ToSameDiffTests.testToSameDiff(withFrozen, in, null); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/TestToSameDiff.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/TestToSameDiff.java new file mode 100644 index 000000000000..3d7c5829ef5b --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/TestToSameDiff.java @@ -0,0 +1,588 @@ +/* + * ****************************************************************************** + * * Copyright (c) 2020 Konduit K.K. + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.samediff; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder; +import org.deeplearning4j.nn.conf.Updater; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.PoolingType; +import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.util.ToSameDiffUtils; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.regularization.Regularization; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; +import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; + +@Slf4j +public class TestToSameDiff extends BaseDL4JTest { + + @Override + public long getTimeoutMilliseconds() { + return super.getTimeoutMilliseconds() * 10; + } + + private static final double eps = 1e-2; + + private static final String expectedSummary = "--- Summary ---\n" + + "Variables: 30 (8 with arrays)\n" + + "Functions: 20 \n" + + "SameDiff Function Defs: 0 \n" + + "Loss function variables: [loss]\n" + + "\n" + + "--- Variables ---\n" + + "- Name - - Array Shape - - Variable Type - - Data Type- - Output Of Function - - Inputs To Functions -\n" + + "input [-1, 1, 28, 28] PLACEHOLDER FLOAT [layer0/inputPreprocessor/reshape]\n" + + "layer0/inputPreprocessor/reshape - ARRAY FLOAT layer0/inputPreprocessor/reshape(reshape) [layer0/conv2d] \n" + + "layer0/b [1, 20] VARIABLE FLOAT [layer0/conv2d] \n" + + "layer0/W [20, 1, 5, 5] VARIABLE FLOAT [layer0/conv2d] \n" + + "layer0/conv2d - ARRAY FLOAT layer0/conv2d(conv2d) [layer1/maxpool2d] \n" + + "layer1/maxpool2d - ARRAY FLOAT layer1/maxpool2d(maxpool2d) [layer2/conv2d] \n" + + "layer2/b [1, 50] VARIABLE FLOAT [layer2/conv2d] \n" + + "layer2/W [50, 20, 5, 5] VARIABLE FLOAT [layer2/conv2d] \n" + + "layer2/conv2d - ARRAY FLOAT layer2/conv2d(conv2d) [layer3/maxpool2d] \n" + + "layer3/maxpool2d - ARRAY FLOAT layer3/maxpool2d(maxpool2d) [layer4/inputPreprocessor/reshape]\n" + + "layer4/inputPreprocessor/reshape - ARRAY FLOAT layer4/inputPreprocessor/reshape(reshape) [layer4/mmul] \n" + + "layer4/b [1, 500] VARIABLE FLOAT [layer4/add] \n" + + "layer4/W [800, 500] VARIABLE FLOAT [layer4/mmul] \n" + + "layer4/mmul - ARRAY FLOAT layer4/mmul(mmul) [layer4/add] \n" + + "layer4/add - ARRAY FLOAT layer4/add(add) [layer4/relu] \n" + + "layer4/relu - ARRAY FLOAT layer4/relu(relu) [layer5/mmul] \n" + + "layer5/b [1, 10] VARIABLE FLOAT [layer5/add] \n" + + "layer5/W [500, 10] VARIABLE FLOAT [layer5/mmul] \n" + + "layer5/mmul - ARRAY FLOAT layer5/mmul(mmul) [layer5/add] \n" + + "layer5/add - ARRAY FLOAT layer5/add(add) [layer5/softmax] \n" + + "layer5/softmax - ARRAY FLOAT layer5/softmax(softmax) [layer5/loss/ClipByValue]\n" + + "labels [-1, 10] PLACEHOLDER FLOAT [layer5/loss/multiply, layer5/loss/size_at]\n" + + "layer5/loss/ClipByValue - ARRAY FLOAT layer5/loss/ClipByValue(ClipByValue) [layer5/loss/log] \n" + + "layer5/loss/log - ARRAY FLOAT layer5/loss/log(log) [layer5/loss/multiply]\n" + + "layer5/loss/multiply - ARRAY FLOAT layer5/loss/multiply(multiply) [layer5/loss/neg] \n" + + "layer5/loss/neg - ARRAY FLOAT layer5/loss/neg(neg) [layer5/loss/reduce_sum]\n" + + "layer5/loss/reduce_sum - ARRAY FLOAT layer5/loss/reduce_sum(reduce_sum) [layer5/loss/divide]\n" + + "layer5/loss/size_at - ARRAY LONG layer5/loss/size_at(size_at) [layer5/loss/cast] \n" + + "layer5/loss/cast - ARRAY FLOAT layer5/loss/cast(cast) [layer5/loss/divide]\n" + + "loss - ARRAY FLOAT layer5/loss/divide(divide) \n" + + "\n" + + "\n" + + "--- Functions ---\n" + + " - Function Name - - Op - - Inputs - - Outputs - \n" + + "0 layer0/inputPreprocessor/reshape Reshape [input] [layer0/inputPreprocessor/reshape] \n" + + "1 layer0/conv2d Conv2D [layer0/inputPreprocessor/reshape, layer0/W, layer0/b] [layer0/conv2d] \n" + + "2 layer1/maxpool2d MaxPooling2D [layer0/conv2d] [layer1/maxpool2d] \n" + + "3 layer2/conv2d Conv2D [layer1/maxpool2d, layer2/W, layer2/b] [layer2/conv2d] \n" + + "4 layer3/maxpool2d MaxPooling2D [layer2/conv2d] [layer3/maxpool2d] \n" + + "5 layer4/inputPreprocessor/reshape Reshape [layer3/maxpool2d] [layer4/inputPreprocessor/reshape] \n" + + "6 layer4/mmul Mmul [layer4/inputPreprocessor/reshape, layer4/W] [layer4/mmul] \n" + + "7 layer4/add AddOp [layer4/mmul, layer4/b] [layer4/add] \n" + + "8 layer4/relu RectifiedLinear [layer4/add] [layer4/relu] \n" + + "9 layer5/mmul Mmul [layer4/relu, layer5/W] [layer5/mmul] \n" + + "10 layer5/add AddOp [layer5/mmul, layer5/b] [layer5/add] \n" + + "11 layer5/softmax SoftMax [layer5/add] [layer5/softmax] \n" + + "12 layer5/loss/ClipByValue ClipByValue [layer5/softmax] [layer5/loss/ClipByValue] \n" + + "13 layer5/loss/log Log [layer5/loss/ClipByValue] [layer5/loss/log] \n" + + "14 layer5/loss/multiply MulOp [layer5/loss/log, labels] [layer5/loss/multiply] \n" + + "15 layer5/loss/neg Negative [layer5/loss/multiply] [layer5/loss/neg] \n" + + "16 layer5/loss/reduce_sum Sum [layer5/loss/neg] [layer5/loss/reduce_sum] \n" + + "17 layer5/loss/size_at SizeAt [labels] [layer5/loss/size_at] \n" + + "18 layer5/loss/cast Cast [layer5/loss/size_at] [layer5/loss/cast] \n" + + "19 layer5/loss/divide DivOp [layer5/loss/reduce_sum, layer5/loss/cast] [loss] \n"; + + @Override + public DataType getDataType() { + return DataType.DOUBLE; + } + + public static void testSameDiffInference(MultiLayerNetwork network, SameDiff sameDiff, INDArray input, + String name) { + INDArray dl4j = network.output(input.dup()); + INDArray sd = sameDiff.batchOutput() + .input("input", input.dup()) + .output(sameDiff.outputs().get(0)) + .outputSingle(); + + if (sd.isNaN().any() && dl4j.isNaN().any()) { + return; + } + +// assertEquals("Sums of DL4J and SameDiff outputs differ for " + name, dl4j.sumNumber().doubleValue(), sd.sumNumber().doubleValue(), eps); + + assertTrue("Output of DL4J and SameDiff differ for " + name, dl4j.equalsWithEps(sd, eps)); + } + + public static void testWeights(MultiLayerNetwork network, SameDiff sameDiff, String name) { + List names = ToSameDiffUtils.getScopeNames(network.getLayers()); + for (int i = 0; i < network.getnLayers(); i++) { + String nameScope = names.get(i); + for (Map.Entry entry : network.getLayer(i).paramTable().entrySet()) { + String paramName = entry.getKey(); + INDArray dl4j = entry.getValue(); + INDArray sd = sameDiff.getArrForVarName(nameScope + "/" + paramName); + + assertTrue("Weight " + nameScope + "/" + paramName + " differs for" + name, dl4j.equalsWithEps(sd, eps)); + } + } + } + + public static void testBackprop(MultiLayerNetwork network, SameDiff sameDiff, INDArray input, INDArray labels) { + network.setInput(input); + network.setLabels(labels); + + network.computeGradientAndScore(); + + int batchSize = (int) input.size(0); + + double dl4jScore = network.score(); + double sdScore = sameDiff.batchOutput() + .input("labels", labels) + .input("input", input) + .output(sameDiff.getLossVariables().get(0)) + .outputSingle().sumNumber().doubleValue(); + assertEquals("Losses differed", dl4jScore, sdScore, eps); + + Map dl4jGradient = network.gradient().gradientForVariable(); + + boolean has2ndLayer = dl4jGradient.containsKey("1_W"); + + INDArray dl4jWeightGrad = dl4jGradient.get("0_W"); + INDArray dl4jBiasGrad = dl4jGradient.get("0_b"); + + Map placeholderMap = new HashMap<>(); + placeholderMap.put("labels", labels); + placeholderMap.put("input", input); + + Set gradientVars = new HashSet<>(); + + for (String k : sameDiff.variableMap().keySet()) { + if (sameDiff.getVariable(k).dataType().isFPType()) { + gradientVars.add(k); + } + } + + Map sameDiffGradient = sameDiff.calculateGradients(placeholderMap, gradientVars); + + // SameDiff does its batch div in the gradient calc, however DL4J does it afterwards + for (Map.Entry entry : sameDiffGradient.entrySet()) { + entry.setValue(entry.getValue().mul(batchSize)); + } + + INDArray sdWeightGrad = sameDiffGradient.get("layer0/W"); + INDArray sdBiasGrad = sameDiffGradient.get("layer0/b"); + + assertTrue("Weight 0 gradient differs", dl4jWeightGrad.equalsWithEps(sdWeightGrad, eps)); + assertTrue("Bias 0 gradient differs", dl4jBiasGrad.equalsWithEps(sdBiasGrad, eps)); + + if (has2ndLayer) { + + INDArray dl4jWeightGrad2 = dl4jGradient.get("1_W"); + INDArray dl4jBiasGrad2 = dl4jGradient.get("1_b"); + + INDArray sdWeightGrad2 = sameDiffGradient.get("layer1/W"); + INDArray sdBiasGrad2 = sameDiffGradient.get("layer1/b"); + + assertTrue("Weight 1 gradient differs", dl4jWeightGrad2.equalsWithEps(sdWeightGrad2, eps)); + assertTrue("Bias 1 gradient differs", dl4jBiasGrad2.equalsWithEps(sdBiasGrad2, eps)); + } + } + + public static void testSameDiffInference(ComputationGraph network, SameDiff sameDiff, INDArray input, String name) { + INDArray dl4j = network.output(input)[0]; + INDArray sd = sameDiff.batchOutput() + .input("in", input) + .output(sameDiff.outputs().get(0)) + .outputSingle(); + + assertTrue("Output of DL4J and SameDiff differ for " + name, dl4j.equalsWithEps(sd, eps)); + } + + @Test + public void testMcXent() { + Nd4j.getRandom().setSeed(123); + + ILossFunction loss = new LossMCXENT(); + IActivation activation = new ActivationSoftmax(); + + INDArray input = Nd4j.rand(5, 4); + INDArray labels = Nd4j.rand(5, 4); + + INDArray dl4grad = loss.computeGradient(labels.dup(), input.dup(), activation, null); + + SameDiff sameDiff = SameDiff.create(); + SDVariable inputVar = sameDiff.placeHolder("input", input.dataType(), input.shape()); + SDVariable labelsVar = sameDiff.placeHolder("labels", labels.dataType(), labels.shape()); + + SDVariable out = sameDiff.nn.softmax(inputVar); + // not dividing by batch size as dl4j does it later + SDVariable lossVar = sameDiff.math.log(out).mul(labelsVar).neg().sum(); + + sameDiff.setLossVariables(lossVar); + + Map placeholderMap = new HashMap<>(); + placeholderMap.put("input", input.dup()); + placeholderMap.put("labels", labels.dup()); + + sameDiff.createGradFunction("input"); + + INDArray sdGrad = sameDiff.calculateGradients(placeholderMap, lossVar.name(), "input").get("input"); + + assertTrue(dl4grad.equalsWithEps(sdGrad, eps)); + } + + @Test + public void testSimple() throws IOException { + int seed = 123; + + Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); + + boolean[] useDenses = {false}; + Updater[] updaters = {Updater.SGD, Updater.ADAM, Updater.ADAMAX, Updater.ADADELTA, Updater.NESTEROVS, + Updater.NADAM/*, Updater.ADAGRAD, Updater.RMSPROP*/, Updater.NONE}; + Regularization[] regularizations = {null}; // {new L2Regularization(0.0005), new L1Regularization(0.005), +// new WeightDecay(0.03, true)}; + LossFunction[] lossFunctions = {LossFunction.MSE, LossFunction.L1/*, LossFunction.MCXENT*/, + /*LossFunction.COSINE_PROXIMITY,*/ LossFunction.HINGE, + LossFunction.SQUARED_HINGE/*, LossFunction.KL_DIVERGENCE*/, LossFunction.MEAN_ABSOLUTE_ERROR, + LossFunction.L2/*, LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR*//*, + LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR*//*, LossFunction.POISSON*/, LossFunction.WASSERSTEIN}; + + Activation[] activations = {/*Activation.CUBE, */Activation.ELU, Activation.HARDSIGMOID, Activation.HARDTANH, + Activation.IDENTITY, Activation.LEAKYRELU, Activation.RATIONALTANH, Activation.RELU, Activation.RELU6, + Activation.RRELU, Activation.SIGMOID/*, Activation.SOFTMAX*/, Activation.SOFTPLUS, Activation.SOFTSIGN, + Activation.TANH, Activation.RECTIFIEDTANH, Activation.SELU, Activation.SWISH, + Activation.THRESHOLDEDRELU, Activation.GELU, Activation.MISH}; + + List failures = new ArrayList<>(); + + for (Updater updater : updaters) { + for (LossFunction lossFunction : lossFunctions) { + for (Activation activation : activations) { + for (boolean useDense : useDenses) { + for (Regularization regularization : regularizations) { + + if (updater == Updater.CUSTOM) { + continue; + } + + IUpdater iUpdater = updater.getIUpdaterWithDefaultConfig(); + + log.info("Test with {}, {}, {}, {}, and {}", useDense ? "dense layer" : "no dense layer", + regularization, activation, lossFunction, iUpdater); + + try { + Nd4j.getRandom().setSeed(seed); + + ListBuilder partial = new NeuralNetConfiguration.Builder() + .seed(seed) + .dataType(DataType.DOUBLE) + .updater(iUpdater) + .regularization( + regularization != null ? Collections.singletonList(regularization) + : Collections.emptyList()) + .regularizationBias( + regularization != null ? Collections.singletonList(regularization) + : Collections.emptyList()) + .list(); + + if (useDense) { + partial.layer(new DenseLayer.Builder() + .activation(Activation.RELU) + .nOut(4).build()); + } + + MultiLayerConfiguration config = partial + .layer(new OutputLayer.Builder(lossFunction) + .activation(activation).nIn(4).nOut(3).build()) + .setInputType(InputType.feedForward(4)) + .validateOutputLayerConfig(false) + .build(); + + MultiLayerNetwork network = new MultiLayerNetwork(config); + network.init(); + + Nd4j.getRandom().setSeed(seed); + + INDArray example = Nd4j.rand(5, 4).mul(2); + DataSet ds = new DataSet(Nd4j.rand(5, 4).mul(2), Nd4j.rand(5, 3).mul(2)); + DataSetIterator iter = new SingletonDataSetIterator(ds); + + // --- training tests --- + + // train DL4J first + network.fit(iter, 1); + assertEquals(1, network.getIterationCount()); + assertEquals(1, network.getEpochCount()); + iter.reset(); + + // copy (w/ params and updater state) + + SameDiff mnistSameDiff; + try { + mnistSameDiff = network.toSameDiff(null, false, false); + } catch (UnsupportedOperationException e) { + continue; + } + testSameDiffInference(network, mnistSameDiff, example, "Post DL4J Training"); + +// testBackprop(network, mnistSameDiff, ds.getFeatures().dup(), ds.getLabels().dup()); + + // train 2 more epochs +// iter.reset(); +// mnistSameDiff.fit(iter, 1); +// assertEquals(1, mnistSameDiff.getTrainingConfig().getIterationCount()); +// assertEquals(1, mnistSameDiff.getTrainingConfig().getEpochCount()); +// +// iter.reset(); +// network.fit(iter, 1); +// assertEquals(1, network.getIterationCount()); +// assertEquals(1, network.getEpochCount()); +// +// testSameDiffInference(network, mnistSameDiff, example, "Post 1st Training"); + + testWeights(network, mnistSameDiff, "Copy"); + + iter.reset(); + mnistSameDiff.fit(iter, 1); + assertEquals(2, mnistSameDiff.getTrainingConfig().getIterationCount()); + assertEquals(2, mnistSameDiff.getTrainingConfig().getEpochCount()); + + iter.reset(); + network.fit(iter, 1); + assertEquals(2, network.getIterationCount()); + assertEquals(2, network.getEpochCount()); + + testWeights(network, mnistSameDiff, "Post Train"); + testSameDiffInference(network, mnistSameDiff, example, "Post 2nd Training"); + } catch (AssertionError ae) { + ae.printStackTrace(); + failures.add((useDense ? "Dense Layer " : "No Dense Layer ") + " with " + regularization + + ", " + activation + + ", " + lossFunction + + ", and " + iUpdater); + } + } + } + } + } + } + + log.info(" --- Failures --- "); + for (String f : failures) { + log.info(f); + } + + assertTrue("There were failed tests", failures.isEmpty()); + + } + + @Test + public void testConversionAndTraining() throws IOException { + int seed = 123; + int outputNum = 10; + + Nd4j.getRandom().setSeed(seed); + + MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() + .seed(seed) + .l2(0.0005) + .l2Bias(0.0005) + .weightInit(WeightInit.XAVIER) + .updater(new Adam()) + .list() + .layer(new ConvolutionLayer.Builder(5, 5) + .stride(1, 1) + .nOut(20) + .activation(Activation.IDENTITY) + .build()) + .layer(new SubsamplingLayer.Builder(PoolingType.MAX) + .kernelSize(2, 2) + .stride(2, 2) + .build()) + .layer(new ConvolutionLayer.Builder(5, 5) + .stride(1, 1) + .nOut(50) + .activation(Activation.IDENTITY) + .build()) + .layer(new SubsamplingLayer.Builder(PoolingType.MAX) + .kernelSize(2, 2) + .stride(2, 2) + .build()) + .layer(new DenseLayer.Builder().activation(Activation.RELU) + .nOut(500).build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nOut(outputNum) + .activation(Activation.SOFTMAX) + .build()) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + + MultiLayerNetwork network = new MultiLayerNetwork(config); + network.init(); + + Nd4j.getRandom().setSeed(seed); + SameDiff mnistSameDiff = network.toSameDiff(null, false, false); + + assertEquals("More than one output", 1, mnistSameDiff.outputs().size()); + assertEquals("More than one loss", 1, mnistSameDiff.getLossVariables().size()); + assertNotNull(mnistSameDiff.getTrainingConfig()); + + assertEquals("Summaries aren't equal", expectedSummary, mnistSameDiff.summary()); + + MnistDataSetIterator trainData = new MnistDataSetIterator(2, 2); + + INDArray example = trainData.next().getFeatures().dup(); + + testSameDiffInference(network, mnistSameDiff, example, "Inference"); + + // --- training tests --- + + // train DL4J first + network.fit(trainData, 1); + trainData.reset(); + + // copy (w/ params and updater state) + + mnistSameDiff = network.toSameDiff(null, false, false); + testSameDiffInference(network, mnistSameDiff, example, "Post DL4J Training"); + + // train 2 more epochs + trainData.reset(); + mnistSameDiff.fit(trainData, 1); + + trainData.reset(); + network.fit(trainData, 1); + + testSameDiffInference(network, mnistSameDiff, example, "Post 2nd Training"); + } + + @Test + public void testConversionAndTrainingGraph() throws IOException { + int seed = 123; + int outputNum = 10; + + MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() + .seed(seed) +// .l2(0.0005) +// .l2Bias(0.0005) +// .weightInit(WeightInit.XAVIER) + .updater(new Adam(eps)) + .list() + .layer(new ConvolutionLayer.Builder(5, 5) + .stride(1, 1) + .nOut(20) + .activation(Activation.IDENTITY) + .build()) + .layer(new SubsamplingLayer.Builder(PoolingType.MAX) + .kernelSize(2, 2) + .stride(2, 2) + .build()) + .layer(new ConvolutionLayer.Builder(5, 5) + .stride(1, 1) + .nOut(50) + .activation(Activation.IDENTITY) + .build()) + .layer(new SubsamplingLayer.Builder(PoolingType.MAX) + .kernelSize(2, 2) + .stride(2, 2) + .build()) + .layer(new DenseLayer.Builder().activation(Activation.RELU) + .nOut(500).build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nOut(outputNum) + .activation(Activation.SOFTMAX) + .build()) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(config); + net.init(); + + ComputationGraph graph = net.toComputationGraph(); + graph.init(); + + Map inputTypes = new HashMap<>(); + inputTypes.put("in", InputType.convolutionalFlat(28, 28, 1)); + SameDiff mnistSameDiff = graph.toSameDiff(inputTypes, true, true); + + assertEquals("More than one output", 1, mnistSameDiff.outputs().size()); + assertEquals("More than one loss", 1, mnistSameDiff.getLossVariables().size()); + assertNotNull(mnistSameDiff.getTrainingConfig()); + + MnistDataSetIterator trainData = new MnistDataSetIterator(10, 10); + + INDArray example = trainData.next().getFeatures().dup(); + + testSameDiffInference(graph, mnistSameDiff, example, "Inference"); + + // --- training tests --- + + // train DL4J first + graph.fit(trainData, 2); + trainData.reset(); + + // copy (w/ params and updater state) + + mnistSameDiff = graph.toSameDiff(inputTypes, true, false); + testSameDiffInference(graph, mnistSameDiff, example, "Post DL4J Training"); + + // train 2 more epochs + trainData.reset(); + mnistSameDiff.fit(trainData, 2); + + trainData.reset(); + graph.fit(trainData, 2); + + testSameDiffInference(graph, mnistSameDiff, example, "Post 2nd Training"); + } +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/ToSameDiffTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/ToSameDiffTests.java new file mode 100644 index 000000000000..6b1aa9f19951 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/ToSameDiffTests.java @@ -0,0 +1,763 @@ +/* + * ****************************************************************************** + * * Copyright (c) 2020 Konduit K.K. + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.samediff; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.common.reflect.ClassPath; +import com.google.common.reflect.ClassPath.ClassInfo; +import java.io.IOException; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.api.layers.IOutputLayer; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.dropout.IDropout; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer; +import org.deeplearning4j.nn.conf.layers.CnnLossLayer; +import org.deeplearning4j.nn.conf.layers.Convolution1D; +import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; +import org.deeplearning4j.nn.conf.layers.Convolution2D; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerWithLoss; +import org.deeplearning4j.nn.conf.layers.Pooling1D; +import org.deeplearning4j.nn.conf.layers.Pooling2D; +import org.deeplearning4j.nn.conf.layers.RnnLossLayer; +import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer; +import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex; +import org.deeplearning4j.nn.layers.BaseOutputLayer; +import org.deeplearning4j.nn.layers.LossLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.util.ToSameDiffUtils; +import org.junit.runner.Result; +import org.junit.runner.notification.RunListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.base.Preconditions; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.lossfunctions.ILossFunction; + + +@Slf4j +public class ToSameDiffTests extends RunListener { + + public static boolean SKIP_UNIMPLEMENTED = true; + public static boolean FAIL_FAST = true; + public static boolean FAIL_IF_MISSING = false; + // makes it show up in IDEA test runs + public static boolean PRINT_AFTER_EVERY = true; + + private static final Set failurePointLayers = new HashSet<>(); + private static final Set failurePointVertices = new HashSet<>(); + private static final Set failureLosses = new HashSet<>(); + + private static void cleanupLayers(Set> layers){ + if(layers.remove(Convolution1D.class)) + layers.add(Convolution1DLayer.class); + + if(layers.remove(Convolution2D.class)) + layers.add(ConvolutionLayer.class); + + if(layers.remove(Pooling1D.class)) + layers.add(Subsampling1DLayer.class); + + if(layers.remove(Pooling2D.class)) + layers.add(SubsamplingLayer.class); + } + + private static void cleanupLosses(Set> layers){ + + } + + private static void cleanupDropouts(Set> layers){ + + } + + private static void cleanupActivations(Set> layers){ + + } + + private static void cleanupPreprocessors(Set> layers){ + + } + + private static void cleanupVertices(Set> layers){ + + } + + private static Set> findClasses(Class superClass, String topPackage) { + + Set infos; + try { + infos = ClassPath.from(superClass.getClassLoader()).getTopLevelClassesRecursive(topPackage); + } catch (IOException e) { + infos = new HashSet<>(); + } + + Set> classes = new HashSet<>(); + for(ClassInfo ci : infos){ + Class c = ci.load(); + if(superClass.isAssignableFrom(c) && + !Modifier.isAbstract(c.getModifiers()) && + !c.isInterface() && + !c.getSimpleName().toLowerCase().contains("custom")) + classes.add(c.asSubclass(superClass)); + } + return classes; + } + + private static Set> findLayers() { + Set> ret = findClasses(Layer.class, "org.deeplearning4j.nn.conf.layers"); + cleanupLayers(ret); + return ret; + } + + private static Set> findLosses() { + Set> ret = findClasses(ILossFunction.class, "org.nd4j.linalg.lossfunctions"); + cleanupLosses(ret); + return ret; + } + + private static Set> findDropouts() { + Set> ret = findClasses(IDropout.class, "org.deeplearning4j.nn.conf.dropout"); + cleanupDropouts(ret); + return ret; + } + + private static Set> findActivations() { + Set> ret = findClasses(IActivation.class, "org.nd4j.linalg.activations"); + cleanupActivations(ret); + return ret; + } + + private static Set> findPreprocessors() { + Set> ret = findClasses(InputPreProcessor.class, "org.deeplearning4j.nn.conf.preprocessor"); + cleanupPreprocessors(ret); + return ret; + } + + private static Set> findVertices() { + Set> ret = findClasses(GraphVertex.class, "org.deeplearning4j.nn.graph.vertex.impl"); + cleanupVertices(ret); + return ret; + } + + private enum Stage{ + Conversion, Output, Loss; + + public Set> testedLayers = new HashSet<>(); + public Set> testedLosses = new HashSet<>(); + public Set> testedDropouts = new HashSet<>(); + public Set> testedActivations = new HashSet<>(); + public Set> testedPreprocessors = new HashSet<>(); + public Set> testedVertices = new HashSet<>(); + + public void cleanup(){ + cleanupLayers(testedLayers); + cleanupLosses(testedLosses); + cleanupDropouts(testedDropouts); + cleanupActivations(testedActivations); + cleanupPreprocessors(testedPreprocessors); + cleanupVertices(testedVertices); + } + + private static Set minusStr(Set> a, Set> b){ + Set ret = new HashSet<>(); + for(Class c : a){ + if(!b.contains(c)) + ret.add(c.getSimpleName()); + } + return ret; + } + + public int check(Set> foundLayers, + Set> foundLosses, + Set> foundDropouts, + Set> foundActivations, + Set> foundPreprocessors, + Set> foundVertices + ){ + + if(this == Stage.Loss){ + // only care about losses & output/loss layers here + foundDropouts.clear(); + foundActivations.clear(); + foundPreprocessors.clear(); + foundVertices.clear(); + + Set> old = foundLayers; + foundLayers = new HashSet<>(); + for(Class layer : old){ + if(LayerWithLoss.class.isAssignableFrom(layer)) + foundLayers.add(layer); + } + } + + Set missingLayers = minusStr(foundLayers, testedLayers); + Set missingLosses = minusStr(foundLosses, testedLosses); + Set missingDropouts = minusStr(foundDropouts, testedDropouts); + Set missingActivations = minusStr(foundActivations, testedActivations); + Set missingPreprocessors = minusStr(foundPreprocessors, testedPreprocessors); + Set missingVertices = minusStr(foundVertices, testedVertices); + + if(this != Stage.Loss) + log.info(" --- ToSameDiff {} Tests --- ", name()); + else + log.info(" --- ToSameDiff Loss Tests (only layers that define losses and loss functions are shown) --- "); + + log.info("Missing Layers: {}", missingLayers); + + if(this != Stage.Loss) { + log.info("Missing Activations: {}", missingActivations); + } + + log.info("Missing Losses: {}", missingLosses); + + if(this != Stage.Loss) { + log.info("Missing Preprocessors: {}", missingPreprocessors); + log.info("Missing Dropouts: {}", missingDropouts); + log.info("Missing Vertices: {}", missingVertices); + } + + return missingLayers.size() + missingLosses.size() + missingDropouts.size() + + missingActivations.size() + missingPreprocessors.size() + missingVertices.size(); + } + + public void record(InputPreProcessor preProcessor){ + if(preProcessor != null) { + testedPreprocessors.add(preProcessor.getClass()); + } + } + + public void record(IActivation activation){ + if(activation != null){ + testedActivations.add(activation.getClass()); + } + } + + public void record(ILossFunction lossFunction){ + if(lossFunction != null){ + testedLosses.add(lossFunction.getClass()); + } + } + + public void record(IDropout dropout){ + if(dropout != null){ + testedDropouts.add(dropout.getClass()); + } + } + + public void record(Layer layer){ + if(layer == null) + return; + + testedLayers.add(layer.getClass()); + record(layer.getIDropout()); + + if(layer instanceof BaseWrapperLayer){ + record(((BaseWrapperLayer) layer).getUnderlying()); + } else if(layer instanceof FrozenLayer) { + record(((FrozenLayer) layer).getLayer()); + } else if(layer instanceof Bidirectional){ + record(((Bidirectional) layer).getFwd()); + record(((Bidirectional) layer).getBwd()); + } + + if(layer instanceof FeedForwardLayer){ + record(((FeedForwardLayer) layer).getActivationFn()); + } + + if(layer instanceof org.deeplearning4j.nn.conf.layers.BaseOutputLayer){ + record(((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) layer).getLossFn()); + } else if(layer instanceof CnnLossLayer){ + record(((CnnLossLayer) layer).getLossFn()); + } else if(layer instanceof RnnLossLayer){ + record(((RnnLossLayer) layer).getLossFn()); + } else if(layer instanceof org.deeplearning4j.nn.conf.layers.LossLayer){ + record(((org.deeplearning4j.nn.conf.layers.LossLayer) layer).getLossFn()); + } else if(layer instanceof Cnn3DLossLayer){ + record(((Cnn3DLossLayer) layer).getLossFn()); + } + } + + public void record(GraphVertex vertex){ + if(vertex == null) + return; + + testedVertices.add(vertex.getClass()); + if(vertex.hasLayer()){ + record(vertex.getLayer().conf().getLayer()); + + } + + if(vertex instanceof LayerVertex) + record(((LayerVertex) vertex).getLayerPreProcessor()); + + } + } + + public static void testToSameDiff(@NonNull MultiLayerNetwork network, INDArray input, INDArray labels){ + testToSameDiff(network, null, input, labels); + } + + private static ILossFunction getLossFn(org.deeplearning4j.nn.api.Layer layer){ + ILossFunction lossFn = null; + if(layer instanceof BaseOutputLayer){ + lossFn = ((BaseOutputLayer) layer).getLossFn(); + } else if(layer instanceof LossLayer){ + lossFn = ((LossLayer) layer).getLossFn(); + } else if(layer instanceof org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer){ + lossFn = ((org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer) layer).getLossFn(); + } else if(layer instanceof org.deeplearning4j.nn.layers.convolution.CnnLossLayer){ + lossFn = ((org.deeplearning4j.nn.layers.convolution.CnnLossLayer) layer).getLossFn(); + } else if(layer instanceof org.deeplearning4j.nn.layers.recurrent.RnnLossLayer){ + lossFn = ((org.deeplearning4j.nn.layers.recurrent.RnnLossLayer) layer).getLossFn(); + } + return lossFn; + } + + public static void testToSameDiff(@NonNull MultiLayerNetwork network, InputType inputType, INDArray input, INDArray labels){ + + for(int i = 0 ; i < network.getnLayers() ; i++){ + Layer layer = network.getLayer(i).conf().getLayer(); + Stage.Conversion.record(layer); + Stage.Conversion.record(network.getLayerWiseConfigurations().getInputPreProcess(i)); + } + + SameDiff sameDiff; + try{ + sameDiff = network.toSameDiff(inputType, true, true); + } catch (UnsupportedOperationException e){ + if(!SKIP_UNIMPLEMENTED) + throw e; + else + return; + } catch (IllegalStateException e){ + if((e.getMessage().contains(" convert to SameDiff with different regularizations") || + e.getMessage().contains(" convert to SameDiff with different IUpdaters")) && SKIP_UNIMPLEMENTED) + return; + else + throw e; + } + + if(input == null){ + long[] inputShape = sameDiff.getVariable("input").placeholderShape(); + for(int i = 0 ; i < inputShape.length ; i++){ + if(inputShape[i] == -1) + inputShape[i] = 1; + } + + input = Nd4j.rand(inputShape); + } + + for(int i = 0 ; i < network.getnLayers() ; i++){ + Layer layer = network.getLayer(i).conf().getLayer(); + Stage.Output.record(layer); + Stage.Output.record(network.getLayerWiseConfigurations().getInputPreProcess(i)); + } + + List activations = network.feedForward(input); + activations.remove(0); + + List sdActivationVariables = new ArrayList<>(); + + + List namesByLayer = ToSameDiffUtils.getScopeNames(network.getLayers()); + + List layerClassNames = new ArrayList<>(); + for(int i = 0 ; i < network.getnLayers() ; i++){ + org.deeplearning4j.nn.conf.layers.Layer config = network.getLayerWiseConfigurations().getConf(i).getLayer(); + + String scope = namesByLayer.get(i); + List scopeVars = sameDiff.getVariablesInScope(scope); + layerClassNames.add(config.getClass().getSimpleName()); + if(scopeVars.size() > 0) { + + SDVariable lastVar = null; + for(int j = scopeVars.size() - 1 ; j >= 0 ; j--){ + SDVariable variable = scopeVars.get(j); + + if(!variable.name().contains("/loss/") && !variable.name().endsWith("loss") && !variable.name().endsWith("labels")){ + lastVar = variable; + break; + } + + } + + if(lastVar != null) + sdActivationVariables.add(lastVar.name()); + else + sdActivationVariables.add(sdActivationVariables.get(sdActivationVariables.size() - 1)); + } else + sdActivationVariables.add(sdActivationVariables.get(sdActivationVariables.size() - 1)); + } + + Map sdActivations = sameDiff.batchOutput() + .output(sdActivationVariables.toArray(new String[0])) + .input("input", input) + .output(); + + + assertEquals("Sizes of DL4J activations and found SameDiff activations differ", activations.size(), sdActivationVariables.size()); + + List> messages = new ArrayList<>(); + boolean failed = false; + for(int i = 0 ; i < sdActivationVariables.size() ; i++){ + INDArray sd = sdActivations.get(sdActivationVariables.get(i)); + INDArray dl4j = activations.get(i); + + if(! sd.equalsWithEps(dl4j, 1e-3)) { + + if(!failed) + failurePointLayers.add(network.getLayer(i).conf().getLayer().getClass().getSimpleName()); + + failed = true; + if(FAIL_FAST) + fail("DL4J activation and SameDiff activation not equal for Layer " + layerClassNames.get(i) + " and SDVariable " + sdActivationVariables.get(i)); + else + messages.add(new Pair<>(layerClassNames.get(i), sdActivationVariables.get(i))); + } + } + + StringBuilder message = new StringBuilder("DL4J activation and SameDiff activation not equal for "); + + for(Pair pair : messages) + message.append("Layer ").append(pair.getFirst()).append(" and SDVariable ").append(pair.getSecond()) + .append(", "); + + assertEquals(message.toString(), 0, messages.size()); + + if(labels != null){ + + for(int i = 0 ; i < network.getnLayers() ; i++){ + Layer layer = network.getLayer(i).conf().getLayer(); + Stage.Loss.record(layer); + Stage.Loss.record(network.getLayerWiseConfigurations().getInputPreProcess(i)); + } + + INDArray output = network.output(input).dup(); + network.setLabels(labels); + network.computeGradientAndScore(); + double score = network.score() - network.calcRegularizationScore(true); + + Map sdOutputs = sameDiff.batchOutput() + .output(sameDiff.outputs().get(0), sameDiff.getLossVariables().get(0)) + .input("input", input) + .input("labels", labels) + .output(); + + INDArray sdLoss = sdOutputs.get(sameDiff.getLossVariables().get(0)); + INDArray sdOutput = sdOutputs.get(sameDiff.outputs().get(0)); + + + assertTrue("Outputs don't match for original network and SameDiff version", sdOutput.equalsWithEps(output, 1e-3)); + + double sdScore = sdLoss.sumNumber().doubleValue(); + + ILossFunction lossFn = getLossFn(network.getOutputLayer()); + try { + assertEquals("Losses don't match for original network and SameDiff version" + (lossFn != null ? + " for loss function " + lossFn.getClass().getSimpleName() : ""), + sdScore, score, 1e-3); + } catch (AssertionError ae){ + if(ae.getMessage().contains("Losses don't match") && lossFn != null){ + failureLosses.add(lossFn.getClass().getSimpleName()); + } + throw ae; + } + } + + if(PRINT_AFTER_EVERY) { + printResults(); + } + + } + + public static void testToSameDiff(@NonNull ComputationGraph graph, @NonNull INDArray inputs, INDArray labels){ + INDArray[] labelsArray = null; + if(labels != null) + labelsArray = new INDArray[]{labels}; + + testToSameDiff(graph, new INDArray[]{inputs}, labelsArray); + } + + public static void testToSameDiff(@NonNull ComputationGraph graph, @NonNull INDArray[] inputs, INDArray[] labels){ + testToSameDiff(graph, inputs, labels, null); + } + + public static void testToSameDiff(@NonNull ComputationGraph graph, @NonNull INDArray[] inputs, INDArray[] labels, InputType[] inputTypes){ + Preconditions.checkArgument(inputs.length == graph.getConfiguration().getNetworkInputs().size(), + "Didn't supply the right number of inputs: expected %s, got %s", graph.getConfiguration().getNetworkInputs().size(), inputs.length); + + Map inputTypesMap = new HashMap<>(); + Map inputsMap = new HashMap<>(); + + for(int i = 0 ; i < inputs.length ; i++){ + String name = graph.getConfiguration().getNetworkInputs().get(i); + inputsMap.put(name, inputs[i]); + + if(inputTypes != null && inputTypes.length > i && inputTypes[i] != null) + inputTypesMap.put(name, inputTypes[i]); + else + inputTypesMap.put(name, InputType.inferInputType(inputs[i])); + } + + InputType[] inputVertTypes = new InputType[inputTypesMap.size()]; + int j = 0; + for(String inputName : graph.getConfiguration().getNetworkInputs()){ + inputVertTypes[j] = inputTypesMap.get(inputName); + j++; + } + + try { + graph.getConfiguration().getLayerActivationTypes(true, inputVertTypes); + } catch (Exception e){ + log.warn("Error getting activation types and adding preprocessors for graph", e); + } + + for(GraphVertex vertex : graph.getVertices()){ + Stage.Conversion.record(vertex); + } + + SameDiff sameDiff; + try{ + sameDiff = graph.toSameDiff(inputTypesMap, true, true); + } catch (UnsupportedOperationException e){ + if(!SKIP_UNIMPLEMENTED) + throw e; + else + return; + } catch (IllegalStateException e){ + if((e.getMessage().contains(" convert to SameDiff with different regularizations") || + e.getMessage().contains(" convert to SameDiff with different IUpdaters") || + e.getMessage().equals("Dimension must be set for toSameDiff conversion.")) && + SKIP_UNIMPLEMENTED) + return; + else + throw e; + } + + for(GraphVertex vertex : graph.getVertices()){ + Stage.Output.record(vertex); + } + + Map activations = graph.feedForward(inputs, false); + + for(String inputName : inputsMap.keySet()) + activations.remove(inputName); + + List activationKeys = new ArrayList<>(); + for(String n : graph.getConfiguration().getTopologicalOrderStr()){ + if(activations.containsKey(n)) + activationKeys.add(n); + } + + Map sdActivationVariables = new HashMap<>(); + for(String vertexName : activationKeys){ + List scopeVars = sameDiff.getVariablesInScope(vertexName); + if(!scopeVars.isEmpty()){ + SDVariable lastVar = null; + for(int i = scopeVars.size() - 1 ; i >= 0 ; i--){ + SDVariable variable = scopeVars.get(i); + + if(!variable.name().contains("/loss/") && !variable.name().endsWith("loss") && !variable.name().endsWith("labels")){ + lastVar = variable; + break; + } + + } + + if(lastVar != null) + sdActivationVariables.put(vertexName, lastVar); + else { + List vertexInputs = graph.getConfiguration().getVertexInputs().get(vertexName); + if(vertexInputs.size() == 1){ + sdActivationVariables.put(vertexName, sdActivationVariables.get(vertexInputs.get(0))); + } + } + } + } + + Map sdActivations = sameDiff.batchOutput() + .inputs(inputsMap) + .output(sdActivationVariables.values().toArray(new SDVariable[0])) + .output(); + + assertEquals("Sizes of DL4J activations and found SameDiff activations differ", activations.size(), sdActivationVariables.size()); + + + List> messages = new ArrayList<>(); + boolean failed = false; + for(String vertexName : activations.keySet()){ + INDArray dl4j = activations.get(vertexName); + INDArray sd = sdActivations.get(sdActivationVariables.get(vertexName).name()); + + if(! sd.equalsWithEps(dl4j, 1e-3)) { + GraphVertex vertex = graph.getVertex(vertexName); + + if(!failed){ + if(vertex instanceof LayerVertex) + failurePointLayers.add(vertex.getLayer().conf().getLayer().getClass().getSimpleName()); + else + failurePointVertices.add(vertex.getClass().getSimpleName()); + } + + failed = true; + + String vertexStr = vertexName + "[" + vertex.getClass().getSimpleName(); + + if(vertex.hasLayer()) + vertexStr += "(" + vertex.getLayer().conf().getLayer().getClass().getSimpleName() + ")"; + + vertexStr += "]"; + + if(FAIL_FAST) + fail("DL4J activation and SameDiff activation not equal for Vertex " + vertexStr + " and SDVariable " + sdActivationVariables.get(vertexName).name()); + else + messages.add(new Pair<>(vertexStr, sdActivationVariables.get(vertexName).name())); + } + + } + + StringBuilder message = new StringBuilder("DL4J activation and SameDiff activation not equal for "); + + for(Pair pair : messages) + message.append("Layer ").append(pair.getFirst()).append(" and SDVariable ").append(pair.getSecond()) + .append(", "); + + assertEquals(message.toString(), 0, messages.size()); + + if(sameDiff.getTrainingConfig() != null && labels != null) { + + for(GraphVertex vertex : graph.getVertices()){ + Stage.Loss.record(vertex); + } + + List labelNames = sameDiff.getTrainingConfig().getDataSetLabelMapping(); + Map inputAndLabelMap = new HashMap<>(inputsMap); + Preconditions.checkArgument(labels.length == labelNames.size(), + "Didn't supply the right number of labels: expected %s, got %s", labelNames.size(), labels.length); + + for (int i = 0; i < labels.length; i++) { + inputAndLabelMap.put(labelNames.get(i), labels[i]); + } + + graph.setLabels(labels); + graph.computeGradientAndScore(); + double score = graph.score() - graph.calcRegularizationScore(true); + + Map sdLosses = sameDiff.batchOutput() + .inputs(inputAndLabelMap) + .output(sameDiff.getLossVariables().toArray(new String[0])) + .output(); + + double sdScore = 0; + for(INDArray scoreArr : sdLosses.values()) + sdScore += scoreArr.sumNumber().doubleValue(); + + Set lossFunctions = new HashSet<>(); + for(String name : graph.getConfiguration().getNetworkOutputs()){ + GraphVertex vertex = graph.getVertex(name); + if(vertex.hasLayer()){ + ILossFunction lossFn = getLossFn(vertex.getLayer()); + if(lossFn != null) + lossFunctions.add(lossFn.getClass().getSimpleName()); + } + } + + try { + assertEquals("Losses don't match for original network and SameDiff version, with loss functions " + lossFunctions, + sdScore, score, 1e-3); + } catch (AssertionError ae){ + if(ae.getMessage().contains("Losses don't match") && !lossFunctions.isEmpty()){ + failureLosses.addAll(lossFunctions); + } + throw ae; + } + } + + if(PRINT_AFTER_EVERY) { + printResults(); + } + } + + private static final Set> foundLayers = findLayers(); + private static final Set> foundLosses = findLosses(); + private static final Set> foundDropouts = findDropouts(); + private static final Set> foundActivations = findActivations(); + private static final Set> foundPreprocessors = findPreprocessors(); + private static final Set> foundVertices = findVertices(); + + public static int printResults() { + int conversion = Stage.Conversion.check(foundLayers, foundLosses, foundDropouts, foundActivations, foundPreprocessors, foundVertices); + int output = Stage.Output.check(foundLayers, foundLosses, foundDropouts, foundActivations, foundPreprocessors, foundVertices); + int loss = Stage.Loss.check(foundLayers, foundLosses, foundDropouts, foundActivations, foundPreprocessors, foundVertices); + + if(!(failurePointVertices.isEmpty() && failureLosses.isEmpty() && failurePointLayers.isEmpty())){ + log.info(" --- ToSameDiff Failure Points --- "); + } + + if(!failurePointLayers.isEmpty()){ + log.info("Failure point layers: {}", failurePointLayers); + } + + if(!failurePointVertices.isEmpty()){ + log.info("Failure point vertices: {}", failurePointVertices); + } + + if(!failureLosses.isEmpty()){ + log.info("Failed losses: {}", failureLosses); + } + + return conversion + output + loss; + } + + @Override + public void testRunFinished(Result result) throws Exception { + int failCount = printResults(); + + if(FAIL_IF_MISSING){ + assertEquals("There were missing ToSameDiff tests", 0, failCount); + } else if(failCount > 0){ + log.warn("There were {} missing ToSameDiff tests", failCount); + } + } +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java index 23c46835e8b7..3fadb25d9b23 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java @@ -28,7 +28,6 @@ import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.SameDiffLoss; import java.io.File; import java.io.InputStream; @@ -46,10 +45,10 @@ public class KerasCustomLossTest extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); - public class LogCosh extends SameDiffLoss { + public class LogCosh extends SameDiffNonFusedLoss { @Override - public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { - return sd.math.log(sd.math.cosh(labels.sub(layerInput))); + public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) { + return sameDiff.math.log(sameDiff.math.cosh(labels.sub(layerInput))); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java index 4c860f7c8c5d..af112c2242c1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java @@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.Layer; import org.nd4j.linalg.api.ndarray.INDArray; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.linalg.lossfunctions.ILossFunction; /** * Interface for output layers (those that calculate gradients with respect to a labels array) @@ -64,6 +65,4 @@ public interface IOutputLayer extends Layer, Classifier { * @return A column INDArray of shape [numExamples,1], where entry i is the score of the ith example */ INDArray computeScoreForExamples(double fullNetworkRegScore, LayerWorkspaceMgr workspaceMgr); - - } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java index 92b98eff5136..1f701f442df8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java @@ -17,8 +17,13 @@ package org.deeplearning4j.nn.conf; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -31,6 +36,9 @@ * for pre processing input before passing it * to the neural network. * + * You will most likely want to extend BaseInputPreProcessor when creating a custom preprocessor, + * as it supplies default exception-throwing define* methods. + * * @author Adam Gibson */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @@ -69,4 +77,15 @@ public interface InputPreProcessor extends Serializable, Cloneable { Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize); + + /** + * Define the InputPreProcessor's input transformation in a {@link SameDiff} instance.
+ * If this isn't supported, this method should throw a {@link UnsupportedOperationException} + * like the default implementation in {@link BaseInputPreProcessor}. + * + * @param sameDiff The {@link SameDiff} instance. + * @param input The input to transform. + * @return The transformed input. + */ + @NonNull SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java index 1a61acc4dbd2..0527f116314f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java @@ -90,6 +90,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable { //Counter for the number of epochs completed so far. Used for per-epoch schedules protected int epochCount = 0; + @Getter + protected InputType inputType; + public int getEpochCount() { return epochCount; } @@ -715,6 +718,7 @@ public MultiLayerConfiguration build() { conf.inferenceWorkspaceMode = inferenceWorkspaceMode; conf.cacheMode = cacheMode; conf.dataType = dataType; + conf.inputType = inputType; Nd4j.getRandom().setSeed(conf.getConf(0).getSeed()); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index 462bc9f17c90..b4ad88ed4428 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -101,6 +101,8 @@ public class NeuralNetConfiguration implements Serializable, Cloneable { //Counter for the number of epochs completed so far. Used for per-epoch schedules protected int epochCount = 0; +// protected IUpdater iUpdater; + /** * Creates and returns a deep copy of the configuration. @@ -1094,6 +1096,7 @@ public NeuralNetConfiguration build() { conf.miniBatch = miniBatch; conf.cacheMode = this.cacheMode; conf.dataType = this.dataType; +// conf.iUpdater = iUpdater; configureLayer(layer); if (layer instanceof FrozenLayer) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java index aa2ed34781ab..a1c6736718dd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java @@ -55,7 +55,7 @@ @EqualsAndHashCode(exclude = {"lastPValue","alphaPrime","a","b", "mask"}) @ToString(exclude = {"lastPValue","alphaPrime","a","b"}) @JsonIgnoreProperties({"lastPValue", "alphaPrime", "a", "b", "mask"}) -public class AlphaDropout implements IDropout { +public class AlphaDropout extends BaseDropout { public static final double DEFAULT_ALPHA = 1.6732632423543772; public static final double DEFAULT_LAMBDA = 1.0507009873554804; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/BaseDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/BaseDropout.java new file mode 100644 index 000000000000..10298b977ab0 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/BaseDropout.java @@ -0,0 +1,35 @@ +/* + * ****************************************************************************** + * * Copyright (c) 2020 Konduit K.K. + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.nn.conf.dropout; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +public abstract class BaseDropout implements IDropout { + @Override + public SDVariable defineDropout(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } + + @Override + public IDropout clone() { + throw new UnsupportedOperationException("Clone not implemented for " + this.getClass().getSimpleName()); + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java index acb6afa2c8ee..7e8c3e367972 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java @@ -19,10 +19,13 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -69,7 +72,7 @@ @JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"}) @EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"}) @Slf4j -public class Dropout implements IDropout { +public class Dropout extends BaseDropout { /** * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? @@ -242,6 +245,14 @@ public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iterat return gradAtInput; } + @Override + public SDVariable defineDropout(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + if(pSchedule != null) + throw new UnsupportedOperationException("Scheduled dropout is not supported for SameDiff conversion"); + + return sameDiff.nn.dropout(input, p); + } + @Override public void clear() { mask = null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java index cd25718bc77a..7c3015cc9205 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java @@ -49,7 +49,7 @@ @Data @JsonIgnoreProperties({"noise"}) @EqualsAndHashCode(exclude = {"noise"}) -public class GaussianDropout implements IDropout { +public class GaussianDropout extends BaseDropout { private final double rate; private final ISchedule rateSchedule; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java index d165614abca7..319c720cbd57 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java @@ -33,7 +33,7 @@ * @author Alex Black */ @Data -public class GaussianNoise implements IDropout { +public class GaussianNoise extends BaseDropout { private double stddev; private ISchedule stddevSchedule; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java index 43ba15898a0d..2e3d7a9099aa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java @@ -16,7 +16,10 @@ package org.deeplearning4j.nn.conf.dropout; +import lombok.NonNull; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -59,4 +62,15 @@ public interface IDropout extends Serializable, Cloneable { void clear(); IDropout clone(); + + /** + * Define the dropout for a {@link SameDiff} instance.
+ * If this isn't supported, this method should throw a {@link UnsupportedOperationException} + * like the default implementation in {@link BaseDropout}. + * + * @param sameDiff The {@link SameDiff} instance + * @param input The input to the dropout, typically the output of the previous layer. + * @return The score (loss function value). + */ + @NonNull SDVariable defineDropout(@NonNull SameDiff sameDiff, @NonNull SDVariable input); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java index 4d83e7fe8701..e9fd147c9e40 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java @@ -46,7 +46,7 @@ @Data @JsonIgnoreProperties({"mask"}) @EqualsAndHashCode(exclude = {"mask"}) -public class SpatialDropout implements IDropout { +public class SpatialDropout extends BaseDropout { private double p; private ISchedule pSchedule; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java index 65515153d15d..f2d0721c9a62 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java @@ -18,6 +18,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeRecurrent; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -79,6 +80,15 @@ public String toString() { return "StackVertex()"; } + private boolean compatibleInputTypes(InputType original, InputType other){ + if(original instanceof InputTypeRecurrent && other instanceof InputTypeRecurrent){ + return ((InputTypeRecurrent) original).getFormat().equals(((InputTypeRecurrent) other).getFormat()) && + ((InputTypeRecurrent) original).getSize() == ((InputTypeRecurrent) other).getSize(); + } else { + return original.equals(other); + } + } + @Override public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException { if (vertexInputs.length == 1) @@ -87,11 +97,12 @@ public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws //Check that types are all the same... for( int i=1; i paramTable) { + return activationFn.defineActivation(sameDiff, layerInput); + } + @Override public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java index 7abe0da061f6..57395d7bb3ca 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java @@ -25,6 +25,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.util.NetworkUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.learning.config.IUpdater; @@ -145,6 +147,16 @@ public List getRegularizationByParam(String paramName){ return null; } + /** + * Applies the activation function if it isn't null. + */ + protected SDVariable doActivation(@NonNull SDVariable input){ + if(activationFn != null) + return activationFn.defineActivation(input.getSameDiff(), input); + else + return input; + } + @SuppressWarnings("unchecked") @Getter diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java index e8de58d9fa7f..8774f08284ed 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java @@ -20,18 +20,17 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; -import org.nd4j.linalg.lossfunctions.impl.LossMSE; -import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; @Data @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public abstract class BaseOutputLayer extends FeedForwardLayer { +public abstract class BaseOutputLayer extends FeedForwardLayer implements LayerWithLoss { protected ILossFunction lossFn; protected boolean hasBias = true; @@ -79,6 +78,11 @@ public LayerMemoryReport getMemoryReport(InputType inputType) { .build(); } + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + throw new UnsupportedOperationException("SameDiff loss conversion has not been implemented for " + this.getClass().getSimpleName()); + } @Getter @Setter diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java index 0b98dfad9d32..e10b28cf3ee5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java @@ -16,18 +16,22 @@ package org.deeplearning4j.nn.conf.layers; +import java.util.Map; import lombok.*; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional.Mode; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import java.util.Arrays; import java.util.List; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; @Data @NoArgsConstructor @@ -44,6 +48,27 @@ protected BaseRecurrentLayer(Builder builder) { this.rnnDataFormat = builder.rnnDataFormat; } + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + @NonNull Map paramTable, SDVariable mask, boolean backwards){ + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return defineLayer(sameDiff, layerInput, paramTable, mask, false); + } + + /** + * An optional method to implement that if implemented, defines the bidirectional operation as a single pass. + * If not defined, should throw a {@link UnsupportedOperationException}, in which case the forward and backward + * passes are done seperatly and combined. + */ + public SDVariable defineBidirectional(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + @NonNull Map paramTable, SDVariable mask, Mode mode) { + throw new UnsupportedOperationException("Bidirectional toSameDiff not supported for " + this.getClass().getSimpleName()); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index dcced3aeb2bc..3809421095a3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -30,8 +30,11 @@ import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.regularization.Regularization; @@ -108,6 +111,55 @@ public ParamInitializer initializer() { return BatchNormalizationParamInitializer.getInstance(); } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + if(lockGammaBeta) + throw new UnsupportedOperationException("Locked Gamma & Beta not supported for SameDiff conversion"); + if(useLogStd) + throw new UnsupportedOperationException("LogStd not supported for SameDiff conversion"); + + INDArray beta = params.get(BatchNormalizationParamInitializer.BETA); + INDArray gamma = params.get(BatchNormalizationParamInitializer.GAMMA); + INDArray mean = params.get(BatchNormalizationParamInitializer.GLOBAL_MEAN); + INDArray variance = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR); + + params.put(BatchNormalizationParamInitializer.BETA, Nd4j.squeeze(beta, 0)); + params.put(BatchNormalizationParamInitializer.GAMMA, Nd4j.squeeze(gamma, 0)); + params.put(BatchNormalizationParamInitializer.GLOBAL_MEAN, Nd4j.squeeze(mean, 0)); + params.put(BatchNormalizationParamInitializer.GLOBAL_VAR, Nd4j.squeeze(variance, 0)); + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + if(lockGammaBeta) + throw new UnsupportedOperationException("Locked Gamma & Beta not supported for SameDiff conversion"); + if(useLogStd) + throw new UnsupportedOperationException("LogStd not supported for SameDiff conversion"); + + SDVariable beta = paramTable.get(BatchNormalizationParamInitializer.BETA); + SDVariable gamma = paramTable.get(BatchNormalizationParamInitializer.GAMMA); + SDVariable mean = paramTable.get(BatchNormalizationParamInitializer.GLOBAL_MEAN); + SDVariable variance = paramTable.get(BatchNormalizationParamInitializer.GLOBAL_VAR); + + int axis; + if(cnn2DFormat == CNN2DFormat.NCHW) + axis = 1; + else if(cnn2DFormat == CNN2DFormat.NHWC) + axis = 3; + else + throw new UnsupportedOperationException("Unknown CNN data format " + cnn2DFormat); + + SDVariable output = sameDiff.nn.batchNorm(layerInput, + mean, + variance, + gamma, + beta, + eps, + axis); + return doActivation(output); + } @Override public InputType getOutputType(int layerIndex, InputType inputType) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java index 5769f2a83e24..6156e11b58a8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java @@ -99,7 +99,7 @@ public void setNIn(InputType inputType, boolean override) { } @Override - public SDVariable defineLayer(SameDiff sd, SDVariable input, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sd, SDVariable input, SDVariable mask, Map paramTable) { // input: [mb, inputCapsules, inputCapsuleDimensions] diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java index 34213038a918..13028b263f98 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java @@ -24,7 +24,10 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.CenterLossParamInitializer; +import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java index 1bde3d912e47..d7f588a8d808 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java @@ -22,13 +22,17 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; import java.util.Collection; @@ -59,7 +63,7 @@ @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public class Cnn3DLossLayer extends FeedForwardLayer { +public class Cnn3DLossLayer extends FeedForwardLayer implements LayerWithLoss { protected ILossFunction lossFn; protected Convolution3D.DataFormat dataFormat; @@ -89,6 +93,49 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + SDVariable batch = sameDiff.sizeAt(layerInput, 0); + SDVariable channels; + SDVariable depth; + SDVariable height; + SDVariable width; + + if(dataFormat == DataFormat.NCDHW){ + channels = sameDiff.sizeAt(layerInput, 1); + depth = sameDiff.sizeAt(layerInput, 2); + height = sameDiff.sizeAt(layerInput, 3); + width = sameDiff.sizeAt(layerInput, 4); + layerInput = layerInput.permute(0, 2, 3, 4, 1); + } else if(dataFormat == DataFormat.NDHWC){ + depth = sameDiff.sizeAt(layerInput, 1); + height = sameDiff.sizeAt(layerInput, 2); + width = sameDiff.sizeAt(layerInput, 3); + channels = sameDiff.sizeAt(layerInput, 4); + } else + throw new UnsupportedOperationException("Unknown CNN 3D data format " + dataFormat); + + SDVariable distributedInput = layerInput.reshape(sameDiff.concat(0, sameDiff.constant( + Nd4j.scalar(batch.dataType(), -1)), channels)); + + SDVariable distributedOutput = doActivation(distributedInput); + + SDVariable output = distributedOutput.reshape(sameDiff.concat(0, batch, depth, height, width, channels)); + + if(dataFormat == DataFormat.NCDHW) + return output.permute(0, 4, 1, 2, 3); + else + return output; + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, labels, average); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || (inputType.getType() != InputType.Type.CNN3D diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index 647b187e38ec..617b05e6221f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -30,9 +31,12 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -60,7 +64,7 @@ @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public class CnnLossLayer extends FeedForwardLayer { +public class CnnLossLayer extends FeedForwardLayer implements LayerWithLoss { protected ILossFunction lossFn; protected CNN2DFormat format = CNN2DFormat.NCHW; @@ -90,6 +94,46 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + SDVariable batch = sameDiff.sizeAt(layerInput, 0); + SDVariable channels; + SDVariable height; + SDVariable width; + + if(format == CNN2DFormat.NCHW){ + channels = sameDiff.sizeAt(layerInput, 1); + height = sameDiff.sizeAt(layerInput, 2); + width = sameDiff.sizeAt(layerInput, 3); + layerInput = layerInput.permute(0, 2, 3, 1); + } else if(format == CNN2DFormat.NHWC){ + height = sameDiff.sizeAt(layerInput, 1); + width = sameDiff.sizeAt(layerInput, 2); + channels = sameDiff.sizeAt(layerInput, 3); + } else + throw new UnsupportedOperationException("Unknown CNN data format " + format); + + SDVariable distributedInput = layerInput.reshape(sameDiff.concat(0, sameDiff.constant( + Nd4j.scalar(batch.dataType(), -1)), channels)); + + SDVariable distributedOutput = doActivation(distributedInput); + + SDVariable output = distributedOutput.reshape(sameDiff.concat(0, batch, height, width, channels)); + + if(format == CNN2DFormat.NCHW) + return output.permute(0, 3, 1, 2); + else + return output; + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, labels, average); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || (inputType.getType() != InputType.Type.CNN diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index f4d247670f53..8c38266a9d41 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -19,20 +19,28 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; +import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode; +import org.nd4j.linalg.factory.Nd4j; /** * 1D (temporal) convolutional layer. This layer accepts RNN InputTypes instead of CNN InputTypes @@ -77,6 +85,41 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + INDArray weight = params.get(ConvolutionParamInitializer.WEIGHT_KEY); + params.put(ConvolutionParamInitializer.WEIGHT_KEY, Nd4j.squeeze(weight, 3).permute(2, 1, 0)); + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); + + PaddingMode paddingMode; + + if(convolutionMode == ConvolutionMode.Same) + paddingMode = PaddingMode.SAME; + else if(convolutionMode == ConvolutionMode.Causal) + paddingMode = PaddingMode.CAUSAL; + else + paddingMode = PaddingMode.VALID; + + SDVariable value = sameDiff.cnn.conv1d(layerInput, weight, bias, + Conv1DConfig.builder() + .dataFormat(rnnDataFormat.name()) + .paddingMode(paddingMode) + .k(kernelSize[0]) + .s(stride[0]) + .p(padding[0]) + .d(dilation[0]) + .build() + ); + + return doActivation(value); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index dc88116e5de5..3d70dbc4e338 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.conf.layers; +import java.util.HashMap; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -29,11 +30,14 @@ import org.deeplearning4j.util.Convolution3DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; /** * 3D convolution layer configuration @@ -113,6 +117,33 @@ public ParamInitializer initializer() { return Convolution3DParamInitializer.getInstance(); } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + INDArray weight = params.get(Convolution3DParamInitializer.WEIGHT_KEY); + params.put(Convolution3DParamInitializer.WEIGHT_KEY, weight.permute(2, 3, 4, 1, 0)); + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(Convolution3DParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(Convolution3DParamInitializer.BIAS_KEY); + + SDVariable value = sameDiff.cnn.conv3d(layerInput, weight, bias, + Conv3DConfig.builder() + .dataFormat(this.dataFormat.name()) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .kD(kernelSize[0]).kH(kernelSize[1]).kW(kernelSize[2]) + .sD(stride[0]).sH(stride[1]).sW(stride[2]) + .pD(padding[0]).pH(padding[1]).pW(padding[2]) + .dD(dilation[0]).dH(dilation[1]).dW(dilation[2]) + .biasUsed(hasBias) + .build() + ); + + return doActivation(value); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 3b2d4c0befe7..c7cad4e0e07a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -16,10 +16,25 @@ package org.deeplearning4j.nn.conf.layers; -import lombok.*; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.Setter; +import lombok.ToString; +import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -27,14 +42,13 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; +import org.nd4j.enums.WeightsFormat; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; /** * 2D Convolution layer (for example, spatial convolution over images). Input activations should be format {@code @@ -184,6 +198,27 @@ public ParamInitializer initializer() { return ConvolutionParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); + + SDVariable value = sameDiff.cnn.conv2d(layerInput, weight, bias, + Conv2DConfig.builder() + .dataFormat(this.cnn2dDataFormat.name()) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .kH(kernelSize[0]).kW(kernelSize[1]) + .sH(stride[0]).sW(stride[1]) + .pH(padding[0]).pW(padding[1]) + .dH(dilation[0]).dW(dilation[1]) + .weightsFormat(WeightsFormat.OIYX) + .build() + ); + + return doActivation(value); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index b2f64c89450b..30a07c4ee5d6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -30,11 +31,14 @@ import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; /** * 2D deconvolution layer configuration
@@ -108,6 +112,32 @@ public ParamInitializer initializer() { return DeconvolutionParamInitializer.getInstance(); } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + INDArray weight = params.get(DeconvolutionParamInitializer.WEIGHT_KEY); + params.put(DeconvolutionParamInitializer.WEIGHT_KEY, weight.permute(2, 3, 1, 0)); + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(DeconvolutionParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(DeconvolutionParamInitializer.BIAS_KEY); + + SDVariable value = sameDiff.cnn.deconv2d(layerInput, weight, bias, + DeConv2DConfig.builder() + .dataFormat(this.cnn2dDataFormat.name()) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .kH(kernelSize[0]).kW(kernelSize[1]) + .sH(stride[0]).sW(stride[1]) + .pH(padding[0]).pW(padding[1]) + .dH(dilation[0]).dW(dilation[1]) + .build() + ); + + return doActivation(value); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java index 01bd3ca832c0..73111356644f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -26,17 +27,19 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer; import org.deeplearning4j.nn.layers.convolution.Deconvolution3DLayer; +import org.deeplearning4j.nn.params.Convolution3DParamInitializer; import org.deeplearning4j.nn.params.Deconvolution3DParamInitializer; -import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; /** * 3D deconvolution layer configuration
@@ -110,6 +113,26 @@ public ParamInitializer initializer() { return Deconvolution3DParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(Convolution3DParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(Convolution3DParamInitializer.BIAS_KEY); + + SDVariable value = sameDiff.cnn.deconv3d(layerInput, weight, bias, + DeConv3DConfig.builder() + .dataFormat(this.dataFormat.name()) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .kD(kernelSize[0]).kH(kernelSize[1]).kW(kernelSize[2]) + .sD(stride[0]).sH(stride[1]).sW(stride[2]) + .pD(padding[0]).pH(padding[1]).pW(padding[2]) + .dD(dilation[0]).dH(dilation[1]).dW(dilation[2]) + .build() + ); + + return doActivation(value); + } + @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { if (inputType == null) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java index 67cac076d11b..b5d3061fef5d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java @@ -25,6 +25,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -72,6 +74,27 @@ public ParamInitializer initializer() { return DefaultParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + SDVariable weight = paramTable.get(DefaultParamInitializer.WEIGHT_KEY); + // may be null + SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY); + + SDVariable temp = layerInput.mmul(weight); + + if(hasLayerNorm()){ + SDVariable gain = paramTable.get(DefaultParamInitializer.GAIN_KEY); + temp = sameDiff.nn.layerNorm(temp, gain, bias, false, 1); + } + + if(hasBias()) + temp = temp.add(bias); + + return doActivation(temp); + } + @Override public LayerMemoryReport getMemoryReport(InputType inputType) { InputType outputType = getOutputType(-1, inputType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index 7edf65618dd1..5fb77cd298e0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.DepthwiseConvolution2DLayer; @@ -27,11 +28,15 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; +import org.nd4j.enums.WeightsFormat; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; /** * 2D depth-wise convolution layer configuration. @@ -89,6 +94,31 @@ public ParamInitializer initializer() { return DepthwiseConvolutionParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(DepthwiseConvolutionParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(DepthwiseConvolutionParamInitializer.BIAS_KEY); + + if(depthMultiplier != 1) + throw new UnsupportedOperationException("Can't convert depthwise convolutions wih a depth multiplier != 1"); + + //TODO can't set depthMultiplier? + SDVariable value = sameDiff.cnn.depthWiseConv2d(layerInput, weight, bias, + Conv2DConfig.builder() + .dataFormat(this.cnn2dDataFormat.name()) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .kH(kernelSize[0]).kW(kernelSize[1]) + .sH(stride[0]).sW(stride[1]) + .pH(padding[0]).pW(padding[1]) + .dH(dilation[0]).dW(dilation[1]) + .weightsFormat(WeightsFormat.OIYX) + .build() + ); + + return doActivation(value); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java index 94cad0a9804f..e3310db142e0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java @@ -27,6 +27,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -84,6 +86,12 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return doActivation(layerInput); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java index 6478b6d59b67..4b1bb8550970 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java @@ -23,13 +23,14 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -83,6 +84,25 @@ public ParamInitializer initializer() { return EmbeddingLayerParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { +// SDVariable weight = paramTable.get(EmbeddingLayerParamInitializer.WEIGHT_KEY); +// SDVariable bias = paramTable.get(EmbeddingLayerParamInitializer.BIAS_KEY); +// +// TODO this cast causes a JVM crash +// SDVariable indices = sameDiff.squeeze(layerInput, 1).castTo(DataType.INT64); +// +// System.out.println("Here!"); +// SDVariable out = sameDiff.gather(weight, indices, 1); +// +// if(hasBias) +// out = out.add(bias); +// +// return doActivation(out); + throw new UnsupportedOperationException("Can't convert EmbeddingLayer to SameDiff"); + } + @Override public LayerMemoryReport getMemoryReport(InputType inputType) { //Basically a dense layer, but no dropout is possible here, and no epsilons diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 1b76b6c7b7e4..d7e6b99c9431 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -31,6 +31,8 @@ import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java index d9e10e6f5161..53b1c0caa121 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java @@ -28,6 +28,8 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -106,6 +108,7 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java index ba0b52d8c555..7a01b4b9c223 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java @@ -16,29 +16,52 @@ package org.deeplearning4j.nn.conf.layers; -import lombok.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional.Mode; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers; import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationELU; +import org.nd4j.linalg.activations.impl.ActivationHardSigmoid; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.activations.impl.ActivationSoftPlus; +import org.nd4j.linalg.activations.impl.ActivationSoftSign; +import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.activations.impl.ActivationThresholdedReLU; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; /** * LSTM recurrent neural network layer without peephole connections. Supports CuDNN acceleration - see https://deeplearning4j.konduit.ai/config/backends/config-cudnn for details + * href="https://deeplearning4j.konduit.ai/config/backends/config-cudnn">https://deeplearning4j.konduit.ai/config/backends/config-cudnn + * for details * * @author Alex Black * @see GravesLSTM GravesLSTM class for an alternative LSTM (with peephole connections) @@ -76,9 +99,10 @@ protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Bui @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, - int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("LSTM", getLayerName(), layerIndex, getNIn(), getNOut()); - org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf, networkDataType); + org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf, + networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); @@ -93,6 +117,141 @@ public ParamInitializer initializer() { return LSTMParamInitializer.getInstance(); } + private static LSTMActivations toLSTMActivation(IActivation activationFn){ + if(activationFn instanceof ActivationTanH) + return LSTMActivations.TANH; + else if(activationFn instanceof ActivationReLU) { + ActivationReLU relu = (ActivationReLU) activationFn; + if(relu.getThreshold() != 0 || relu.getNegativeSlope() != 0) + throw new UnsupportedOperationException("LSTM toSameDiff doesn't support ReLU activation with threshold and negative slope."); + + if(relu.getMax() != 0) + throw new UnsupportedOperationException("LSTM toSameDiff doesn't support ReLU activation with max."); + + //TODO no way to pass parms to libnd4j +// if(relu.getNegativeSlope() != 0) +// return LSTMActivations.LEAKY_RELU; +// +// if(relu.getThreshold() != 0) +// return LSTMActivations.THRESHHOLD_RELU; + + return LSTMActivations.RELU; + } else if(activationFn instanceof ActivationSigmoid) + return LSTMActivations.SIGMOID; + else if(activationFn instanceof ActivationLReLU) +// return LSTMActivations.LEAKY_RELU; + //TODO no way to pass parms to libnd4j + throw new UnsupportedOperationException("LSTM toSameDiff doesn't support activation ActivationLReLU"); + else if(activationFn instanceof ActivationThresholdedReLU) +// return LSTMActivations.THRESHHOLD_RELU; + //TODO no way to pass parms to libnd4j + throw new UnsupportedOperationException("LSTM toSameDiff doesn't support activation ActivationThresholdedReLU"); + else if(activationFn instanceof ActivationHardSigmoid) + return LSTMActivations.HARD_SIGMOID; + else if(activationFn instanceof ActivationELU) + return LSTMActivations.ELU; + else if(activationFn instanceof ActivationSoftSign) + return LSTMActivations.SOFTSIGN; + else if(activationFn instanceof ActivationSoftPlus) + return LSTMActivations.SOFTPLUS; + else + //TODO add ActivationThresholdedReLU and ActivationLReLU to list once supported + throw new UnsupportedOperationException("Unsupported activation for LSTM toSameDiff: " + activationFn.getClass().getSimpleName() + + ". Should be one of ActivationTanH, ActivationReLU, ActivationSigmoid, " + + "ActivationHardSigmoid, ActivationELU, ActivationSoftSign, or ActivationSoftPlus."); + } + + /** + * Change weight from [input, forget, output, gate] (DL4J) to [input, forget, gate, output] (SameDiff) + */ + private static INDArray changeWeight(INDArray weight){ + int size = (int) (weight.size(1) / 4); + INDArray input = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size)); + INDArray forget = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2*size)); + INDArray output = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(2*size, 3*size)); + INDArray gate = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(3*size, 4*size)); + return Nd4j.concat(1, input, forget, gate, output); + } + + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + INDArray bias = params.get(LSTMParamInitializer.BIAS_KEY); + params.put(LSTMParamInitializer.BIAS_KEY, Nd4j.squeeze(bias, 0)); + + params.put(LSTMParamInitializer.INPUT_WEIGHT_KEY, + changeWeight(params.get(LSTMParamInitializer.INPUT_WEIGHT_KEY))); + + params.put(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, + changeWeight(params.get(LSTMParamInitializer.RECURRENT_WEIGHT_KEY))); + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + SDVariable recurrentWeight = paramTable.get(LSTMParamInitializer.RECURRENT_WEIGHT_KEY); + SDVariable inputWeight = paramTable.get(LSTMParamInitializer.INPUT_WEIGHT_KEY); + SDVariable bias = paramTable.get(LSTMParamInitializer.BIAS_KEY); + + LSTMActivations gateActivation = toLSTMActivation(gateActivationFn); + LSTMActivations recurrentActivation = toLSTMActivation(activationFn); + + + return sameDiff.rnn.lstmLayer(layerInput, LSTMLayerWeights.builder() + .weights(inputWeight) + .rWeights(recurrentWeight) + .bias(bias) + .build(), + LSTMLayerConfig.builder() + .gateAct(gateActivation) + .cellAct(recurrentActivation) + .outAct(recurrentActivation) + .retFullSequence(true) + .directionMode(LSTMDirectionMode.FWD) + .lstmdataformat(rnnDataFormat == RNNFormat.NCW ? LSTMDataFormat.NST : LSTMDataFormat.NTS) + .build())[0]; + } + + @Override + public SDVariable defineBidirectional(SameDiff sameDiff, SDVariable layerInput, Map paramTable, + SDVariable mask, Mode mode) { + //TODO need different param transforms for bidirectional +// SDVariable recurrentWeight = paramTable.get(LSTMParamInitializer.RECURRENT_WEIGHT_KEY); +// SDVariable inputWeight = paramTable.get(LSTMParamInitializer.INPUT_WEIGHT_KEY); +// SDVariable bias = paramTable.get(LSTMParamInitializer.BIAS_KEY); +// +// LSTMActivations gateActivation = toLSTMActivation(gateActivationFn); +// LSTMActivations recurrentActivation = toLSTMActivation(activationFn); +// +// LSTMDirectionMode directionMode; +// if(mode == Mode.ADD || mode == Mode.AVERAGE) +// directionMode = LSTMDirectionMode.BIDIR_SUM; +// else if(mode == Mode.CONCAT) +// directionMode = LSTMDirectionMode.BIDIR_CONCAT; +// else +// throw new UnsupportedOperationException("Bidirectional not supported for mode " + mode); +// +// LSTMDataFormat format = rnnDataFormat == RNNFormat.NCW ? LSTMDataFormat.NST : LSTMDataFormat.NTS; +// +// SDVariable output = sameDiff.rnn.lstmLayer(layerInput, LSTMLayerWeights.builder() +// .weights(inputWeight) +// .rWeights(recurrentWeight) +// .bias(bias) +// .build(), +// LSTMLayerConfig.builder() +// .gateAct(gateActivation) +// .cellAct(recurrentActivation) +// .directionMode(directionMode) +// .lstmdataformat(format) +// .build())[0]; +// +// if(mode == Mode.AVERAGE) +// return output.div(2); +// else +// return output; + return super.defineBidirectional(sameDiff, layerInput, paramTable, mask, mode); + } + @Override public LayerMemoryReport getMemoryReport(InputType inputType) { //TODO - CuDNN etc diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java index 25577bd1fad6..7f7c09030ed3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.TrainingConfig; @@ -30,6 +31,8 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; @@ -97,6 +100,36 @@ protected void initializeConstraints(Builder builder) { this.iDropout = builder.iDropout; } + + /** + * Define the layer for SameDiff conversion.
+ * If this isn't supported, this method should throw a {@link UnsupportedOperationException} like it does if not overridden. + * + * @param sameDiff SameDiff instance + * @param layerInput Input to the layer + * @param mask Optional, maybe null. Mask to apply if supported + * @param paramTable Parameter table - keys and shapes as defined in the layer implementation class. + * @return The final layer variable corresponding to the activations/output from the forward pass + */ + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, SDVariable mask, + @NonNull Map paramTable){ + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } + + /** + * Do any necessary transforms to parameters (weights, biases, etc) before making SDVariables out of them. + * Useful for things like changing the dimension order or squeezing. + * + * Adding or removing parameters is supported.
+ * + * Should throw a {@link UnsupportedOperationException} if conversion of this layer configuration isn't + * supported and it will cause an error when transforming weights. + * + * @param params The parameters of the layer. + */ + public void transformParamsForSameDiff(@NonNull Map params){ + } + /** * Reset the learning related configs of the layer to default. When instantiated with a global * neural network configuration the parameters specified in the neural network configuration diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerWithLoss.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerWithLoss.java new file mode 100644 index 000000000000..6259b69ec923 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerWithLoss.java @@ -0,0 +1,42 @@ +/* + * ****************************************************************************** + * * Copyright (c) 2020 Konduit K.K. + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.nn.conf.layers; + +import lombok.NonNull; +import org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +/** + * Any loss or output layers that support SameDiff conversion must implement this. + */ +public interface LayerWithLoss { + + /** + * Define the layer's loss function. Should return a scalar.
+ * + * If average is true, should be the batchwise average, otherwise the sum. + * @param sameDiff The {@link SameDiff} instance + * @param input The input to the loss function, the output (activations) of this layer. + * @param labels The labels to compare the output to. The placeholder will be created with the shape of the output (activations) of this layer. May be null if the implementation layer doesn't require labels (e.g. {@link OCNNOutputLayer}. + * @param average Whether to average the loss per example. Most of the time this should be passed to the {@link org.nd4j.linalg.lossfunctions.ILossFunction}. + * @return The loss scalar. + */ + SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, boolean average); +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java index 8ddd20b4558e..3d210675f1b8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java @@ -24,7 +24,6 @@ import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; import org.deeplearning4j.nn.weights.WeightInitUtil; -import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; @@ -149,9 +148,10 @@ public void initializeParameters(Map params) { @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { val baseQueries = paramTable.get(WEIGHT_QUERIES); - val batchSize = layerInput.shape().get(SDIndex.point(0)); + val batchSize = sameDiff.sizeAt(layerInput, 0); val tileAxis = sameDiff.scatterUpdate(sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize); val queries = sameDiff.tile(baseQueries, tileAxis); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index ebfc56a7b1ae..5bf17f72776d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -27,9 +27,12 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; import org.nd4j.linalg.learning.regularization.Regularization; import java.util.Collection; @@ -90,6 +93,25 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + if(dataFormat == CNN2DFormat.NHWC) + layerInput = layerInput.permute(0, 3, 1, 2); + //TODO support more data types + + SDVariable output = sameDiff.cnn.localResponseNormalization(layerInput, LocalResponseNormalizationConfig.builder() + .alpha(alpha) + .beta(beta) //TODO n and k map to bias and depth? guessing based on the paper but data types don't line up + .bias(k) + .depth((int) n) + .build()); + + if(dataFormat == CNN2DFormat.NHWC) + output = output.permute(0, 2, 3, 1); + + return output; + } @Override public InputType getOutputType(int layerIndex, InputType inputType) { @@ -274,7 +296,7 @@ public Builder helperAllowFallback(boolean allowFallback) { * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). * See {@link CNN2DFormat} for more details.
* Default: NCHW - * @param format Format for activations (in and out) + * @param dataFormat Format for activations (in and out) */ public Builder dataFormat(CNN2DFormat dataFormat){ this.dataFormat = dataFormat; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index a5028f08a4f0..c829abb6d65b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -169,7 +169,8 @@ public void initializeParameters(Map params) { } @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); // (outH, featureDim, nOut) int outH = outputSize; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index b65d2fe77c72..e86b0ad3f74b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -173,7 +173,8 @@ public void initializeParameters(Map params) { } @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java index e26dbd83fea1..269a17d8caea 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -28,6 +29,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,7 +52,7 @@ @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public class LossLayer extends FeedForwardLayer { +public class LossLayer extends FeedForwardLayer implements LayerWithLoss { protected ILossFunction lossFn; @@ -90,6 +93,18 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return doActivation(layerInput); + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, labels, average); + } + public static class Builder extends BaseOutputLayer.Builder { public Builder() { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java index 75d86460598d..3a90105a6e8e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java @@ -19,12 +19,15 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -66,6 +69,28 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection paramTable) { + + SDVariable weight = paramTable.get(DefaultParamInitializer.WEIGHT_KEY); + // may be null + SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY); + + SDVariable temp = layerInput.mmul(weight); + + if(hasBias()) + temp = temp.add(bias); + + return doActivation(temp); + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, labels, average); + } + @Override public ParamInitializer initializer() { return DefaultParamInitializer.getInstance(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index b9b8bca4de49..c432b5a0f081 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -27,6 +27,9 @@ import org.deeplearning4j.nn.params.PReLUParamInitializer; import org.deeplearning4j.nn.weights.WeightInitConstant; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.util.ArrayUtil; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -74,6 +77,13 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection paramTable) { + SDVariable alpha = paramTable.get(PReLUParamInitializer.WEIGHT_KEY); + return doActivation(sameDiff.nn.prelu(layerInput, alpha, ArrayUtil.toInts(sharedAxes))); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java index ff4e1cf76b55..3d2cc9e7d704 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java @@ -101,7 +101,7 @@ public PrimaryCapsules(Builder builder){ } @Override - public SDVariable defineLayer(SameDiff SD, SDVariable input, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff SD, SDVariable input, SDVariable mask, Map paramTable) { Conv2DConfig conf = Conv2DConfig.builder() .kH(kernelSize[0]).kW(kernelSize[1]) .sH(stride[0]).sW(stride[1]) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index 10659f326154..4e0d4f228070 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -181,7 +181,8 @@ public void validateInput(INDArray input) { } @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { final val W = paramTable.get(WEIGHT_KEY); final val R = paramTable.get(RECURRENT_WEIGHT_KEY); final val b = paramTable.get(BIAS_KEY); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index f1dcd73a615b..05248153b5dc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -30,8 +31,11 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -53,7 +57,7 @@ @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public class RnnLossLayer extends FeedForwardLayer { +public class RnnLossLayer extends FeedForwardLayer implements LayerWithLoss { private RNNFormat rnnDataFormat = RNNFormat.NCW; protected ILossFunction lossFn; @@ -82,6 +86,43 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + SDVariable batch = sameDiff.sizeAt(layerInput, 0); + SDVariable channels; + SDVariable size; + + if(rnnDataFormat == RNNFormat.NCW){ + channels = sameDiff.sizeAt(layerInput, 1); + size = sameDiff.sizeAt(layerInput, 2); + layerInput = layerInput.permute(0, 2, 1); + } else if(rnnDataFormat == RNNFormat.NWC){ + size = sameDiff.sizeAt(layerInput, 1); + channels = sameDiff.sizeAt(layerInput, 2); + } else + throw new UnsupportedOperationException("Unknown CNN data format " + rnnDataFormat); + + SDVariable distributedInput = layerInput.reshape(sameDiff.concat(0, sameDiff.constant( + Nd4j.scalar(batch.dataType(), -1)), channels)); + + SDVariable distributedOutput = doActivation(distributedInput); + + SDVariable output = distributedOutput.reshape(sameDiff.concat(0, batch, size, channels)); + + if(rnnDataFormat == RNNFormat.NCW) + return output.permute(0, 2, 1); + else + return output; + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, labels, average); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index cfd337514a43..05d0627c82d8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; @@ -28,9 +29,12 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @@ -80,6 +84,48 @@ public ParamInitializer initializer() { return DefaultParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable b = paramTable.get(DefaultParamInitializer.BIAS_KEY); + SDVariable W = paramTable.get(DefaultParamInitializer.WEIGHT_KEY); + + SDVariable batch = sameDiff.sizeAt(layerInput, 0); + SDVariable sequenceLength; + + SDVariable neg1 = sameDiff.constant(Nd4j.scalar(batch.dataType(), -1)); + + if(rnnDataFormat == RNNFormat.NCW) { + sequenceLength = sameDiff.sizeAt(layerInput, 2); + layerInput = layerInput.permute(0, 2, 1); + } else if(rnnDataFormat == RNNFormat.NWC) + sequenceLength = sameDiff.sizeAt(layerInput, 1); + else + throw new UnsupportedOperationException("Unknown RNN data format " + rnnDataFormat); + + SDVariable distributedShape = sameDiff.concat(0, batch.mul(sequenceLength), neg1); + SDVariable distributedInput = layerInput.reshape(distributedShape); + + SDVariable distributedOutput = distributedInput.mmul(W); + if(hasBias) + distributedOutput = distributedOutput.add(b); + + distributedOutput = doActivation(distributedOutput); + + SDVariable temp = distributedOutput.reshape(sameDiff.concat(0, batch, sequenceLength, neg1)); + + if(rnnDataFormat == RNNFormat.NCW) + return temp.permute(0, 2, 1); + else + return temp; + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, labels, average); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java index 79fa765a4984..ca5737b04eaf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java @@ -136,7 +136,8 @@ public void initializeParameters(Map params) { @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { if(projectInput){ val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION); val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index f9ae11b4936e..eec5431bc34a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer; @@ -28,10 +29,14 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.WeightsFormat; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; /** * 2D Separable convolution layer configuration. @@ -149,6 +154,28 @@ public ParamInitializer initializer() { return SeparableConvolutionParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable depthWeight = paramTable.get(SeparableConvolutionParamInitializer.DEPTH_WISE_WEIGHT_KEY); + SDVariable pointWeight = paramTable.get(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY); + SDVariable bias = paramTable.get(SeparableConvolutionParamInitializer.BIAS_KEY); + + SDVariable value = sameDiff.cnn.separableConv2d(layerInput, depthWeight, pointWeight, bias, + Conv2DConfig.builder() + .dataFormat(this.cnn2dDataFormat.name()) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .kH(kernelSize[0]).kW(kernelSize[1]) + .sH(stride[0]).sW(stride[1]) + .pH(padding[0]).pW(padding[1]) + .dH(dilation[0]).dW(dilation[1]) + .weightsFormat(WeightsFormat.OIYX) + .build() + ); + + return doActivation(value); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java index 042f09121a6d..98d676f4a838 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java @@ -27,6 +27,8 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -124,6 +126,13 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + //TODO SameDiff spaceToBatch has issues, see https://github.com/eclipse/deeplearning4j/issues/9019 + throw new UnsupportedOperationException("Can't convert SpaceToBatchLayer to SameDiff"); +// return sameDiff.cnn.spaceToBatch(layerInput, blocks, padding[0], padding[1]); + } @Override public void setNIn(InputType inputType, boolean override) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java index 53d9007be47b..85f7b672d279 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java @@ -26,6 +26,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -125,6 +127,19 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + org.nd4j.enums.DataFormat format; + if(dataFormat == CNN2DFormat.NCHW) + format = org.nd4j.enums.DataFormat.NCHW; + else if(dataFormat == CNN2DFormat.NHWC) + format = org.nd4j.enums.DataFormat.NHWC; + else + throw new UnsupportedOperationException("Unknown CNN data format " + dataFormat); + + return sameDiff.cnn.spaceToDepth(layerInput, blockSize, format); + } @Override public void setNIn(InputType inputType, boolean override) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 5d2a55994e80..3b082c6ac561 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -19,6 +19,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -28,6 +29,8 @@ import org.deeplearning4j.util.Convolution1DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -75,6 +78,16 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + layerInput = sameDiff.expandDims(layerInput, -1); + + SDVariable out = super.defineLayer(sameDiff, layerInput, mask, paramTable); + return sameDiff.squeeze(out, -1); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 67a2e804c1aa..579bda2b846e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; @@ -29,9 +30,12 @@ import org.deeplearning4j.util.Convolution3DUtils; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.learning.regularization.Regularization; @@ -132,6 +136,27 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + Pooling3DConfig poolingConfig = Pooling3DConfig.builder() + .kD(kernelSize[0]).kH(kernelSize[1]).kW(kernelSize[2]) + .sD(stride[0]).sH(stride[1]).sW(stride[2]) + .pD(padding[0]).pH(padding[1]).pW(padding[2]) + .dD(dilation[0]).dH(dilation[1]).dW(dilation[2]) + .isNCDHW(dataFormat == DataFormat.NCDHW) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .build(); + + if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.MAX){ + return sameDiff.cnn.maxPooling3d(layerInput, poolingConfig); + } else if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.AVG){ + return sameDiff.cnn.avgPooling3d(layerInput, poolingConfig); + } else { + throw new UnsupportedOperationException("Can't convert " + poolingType + " pooling layer to SameDiff, only MAX and AVG supported"); + } + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index a434d05bccb6..1f427af4fbb6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -29,12 +29,15 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; /** * Subsampling layer also referred to as pooling in convolution neural nets @@ -147,6 +150,28 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + Pooling2DConfig poolingConfig = Pooling2DConfig.builder() + .kH(kernelSize[0]).kW(kernelSize[1]) + .sH(stride[0]).sW(stride[1]) + .pH(padding[0]).pW(padding[1]) + .dH(dilation[0]).dW(dilation[1]) + .isNHWC(cnn2dDataFormat == CNN2DFormat.NHWC) + .isSameMode(convolutionMode == ConvolutionMode.Same) + .build(); + + if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.MAX){ + return sameDiff.cnn.maxPooling2d(layerInput, poolingConfig); + } else if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.AVG){ + return sameDiff.cnn.avgPooling2d(layerInput, poolingConfig); + } else { + throw new UnsupportedOperationException("Can't convert " + poolingType + " pooling layer to SameDiff, only MAX and AVG supported"); + } + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java index 4b39fa34d22a..a258486c0c18 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.ToString; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -28,6 +29,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -84,6 +87,12 @@ public Upsampling1D clone() { return clone; } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return sameDiff.squeeze(sameDiff.cnn.upsampling2d(sameDiff.expandDims(layerInput, -1), size[0], 1, true), -1); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index 0357c3e7bab9..5361e932f3cc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -26,6 +26,8 @@ import org.deeplearning4j.nn.conf.serde.legacy.LegacyIntArrayDeserializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; @@ -89,6 +91,12 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return sameDiff.cnn.upsampling2d(layerInput, size[0], size[1], format == CNN2DFormat.NCHW); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java index 695212d89d3a..4808d33a9736 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java @@ -20,10 +20,13 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -91,6 +94,12 @@ public InputType getOutputType(int layerIndex, InputType inputType) { return InputType.convolutional3D(size[0] * inDepth, size[1] * inHeight, size[2] * inWidth, inChannels); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return sameDiff.cnn.upsampling3d(layerInput, dataFormat == DataFormat.NCDHW, size[0], size[1], size[2]); + } + @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { if (inputType == null) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java index a3345fde9761..3678d605688a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java @@ -27,12 +27,15 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.factory.Nd4j; /** * Zero padding 1D layer for convolutional neural networks. Allows padding to be done separately for top and bottom. @@ -82,6 +85,23 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + int padLeft = padding[0]; + int padRight = padding[1]; + + //TODO support data formats + int[][] fullPadding = new int[][]{ + {0, 0}, + {0, 0}, + {padLeft, padRight} + }; + + return sameDiff.nn.pad(layerInput, sameDiff.constant(Nd4j.createFromArray(fullPadding)), 0); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java index 8dfd594a6ace..a9578eaefbd3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java @@ -26,12 +26,15 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.factory.Nd4j; /** * Zero padding 3D layer for convolutional neural networks. Allows padding to be done separately for "left" and "right" @@ -70,6 +73,30 @@ public ParamInitializer initializer() { return EmptyParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + //TODO support data formats + int padLeftD = padding[0]; + int padRightD = padding[1]; + int padLeftH = padding[2]; + int padRightH = padding[3]; + int padLeftW = padding[4]; + int padRightW = padding[5]; + + int[][] fullPadding; + fullPadding = new int[][]{ + {0, 0}, + {0, 0}, + {padLeftD, padRightD}, + {padLeftH, padRightH}, + {padLeftW, padRightW} + }; + + return sameDiff.nn.pad(layerInput, sameDiff.constant(Nd4j.createFromArray(fullPadding)), 0); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java index ef92d2f8588a..d15a41b6c500 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java @@ -26,6 +26,8 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,6 +35,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.factory.Nd4j; /** * Zero padding layer for convolutional neural networks (2D CNNs). Allows padding to be done separately for @@ -82,6 +85,37 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + int padTop = padding[0]; + int padBottom = padding[1]; + int padLeft = padding[2]; + int padRight = padding[3]; + + int[][] fullPadding; + if(dataFormat == CNN2DFormat.NCHW){ + fullPadding = new int[][]{ + {0, 0}, + {0, 0}, + {padTop, padBottom}, + {padLeft, padRight} + }; + } else if(dataFormat == CNN2DFormat.NHWC) { + fullPadding = new int[][]{ + {0, 0}, + {padTop, padBottom}, + {padLeft, padRight}, + {0, 0} + }; + } else { + throw new UnsupportedOperationException("Unknown CNN data format " + dataFormat); + } + + return sameDiff.nn.pad(layerInput, sameDiff.constant(Nd4j.createFromArray(fullPadding)), 0); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { int[] hwd = ConvolutionUtils.getHWDFromInputType(inputType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java index b10b0e716b73..d0f04a95100e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java @@ -27,6 +27,9 @@ import org.deeplearning4j.nn.layers.convolution.Cropping1DLayer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -86,6 +89,19 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + private static Integer end(int idx){ + if(idx == 0) + return null; + else + return -idx; + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + return layerInput.get(SDIndex.all(), SDIndex.all(), SDIndex.interval(cropping[0], end(cropping[1]))); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index 3a13e2fc052d..1a98918142df 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -29,6 +29,9 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ConvolutionUtils; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -103,6 +106,25 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + private static Integer end(int idx){ + if(idx == 0) + return null; + else + return -idx; + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + if(dataFormat == CNN2DFormat.NCHW) { + return layerInput.get(SDIndex.all(), SDIndex.all(), SDIndex.interval(cropping[0], end(cropping[1])), SDIndex.interval(cropping[2], end(cropping[3]))); + } else if(dataFormat == CNN2DFormat.NHWC){ + return layerInput.get(SDIndex.all(), SDIndex.interval(cropping[0], end(cropping[1])), SDIndex.interval(cropping[2], end(cropping[3])), SDIndex.all()); + } else { + throw new UnsupportedOperationException("Unknown CNN data format " + dataFormat); + } + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { int[] hwd = ConvolutionUtils.getHWDFromInputType(inputType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java index 74710a4699e1..d1f9c8630529 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java @@ -27,6 +27,9 @@ import org.deeplearning4j.nn.layers.convolution.Cropping3DLayer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -95,6 +98,23 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + private static Integer end(int idx){ + if(idx == 0) + return null; + else + return -idx; + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + //TODO support different dataTypes + return layerInput.get(SDIndex.all(), SDIndex.all(), + SDIndex.interval(cropping[0], end(cropping[1])), + SDIndex.interval(cropping[2], end(cropping[3])), + SDIndex.interval(cropping[4], end(cropping[5]))); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java index ef88dc8b7ac2..1faf30e8269d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java @@ -26,6 +26,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.ElementWiseParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -83,6 +85,17 @@ public ParamInitializer initializer() { return ElementWiseParamInitializer.getInstance(); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable weight = paramTable.get(ElementWiseParamInitializer.WEIGHT_KEY); + SDVariable bias = paramTable.get(ElementWiseParamInitializer.BIAS_KEY); + + SDVariable out = layerInput.mul(weight).add(bias); + + return doActivation(out); + } + /** * This is a report of the estimated memory consumption for the given layer * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java index f72da09e592a..fcebe4c65fc3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java @@ -16,8 +16,10 @@ package org.deeplearning4j.nn.conf.layers.misc; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; @@ -29,12 +31,14 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.FrozenLayerParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.NameScope; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.shade.jackson.annotation.JsonProperty; -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; import java.util.Collection; import java.util.List; @@ -102,6 +106,30 @@ public ParamInitializer initializer() { return FrozenLayerParamInitializer.getInstance(); } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + layer.transformParamsForSameDiff(params); + } + + /** + * Will freeze any params passed to it. + * @param sameDiff SameDiff instance + * @param layerInput Input to the layer + * @param mask Optional, maybe null. Mask to apply if supported + * @param paramTable Parameter table - keys and shapes as defined in the layer implementation class. + */ + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + for(SDVariable variable : paramTable.values()){ + variable.convertToConstant(); + } + NameScope underlyingScope = sameDiff.withNameScope("underlying"); + SDVariable output = layer.defineLayer(sameDiff, layerInput, mask, paramTable); + underlyingScope.close(); + return output; + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { return layer.getOutputType(layerIndex, inputType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java index 468c310329b7..9bd6e522cf0f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java @@ -16,7 +16,9 @@ package org.deeplearning4j.nn.conf.layers.misc; +import java.util.Map; import lombok.Data; +import lombok.NonNull; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -24,6 +26,8 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.params.FrozenLayerWithBackpropParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; @@ -83,6 +87,22 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return new org.deeplearning4j.nn.layers.FrozenLayerWithBackprop(underlying); } + /** + * Will freeze any params passed to it. + * @param sameDiff SameDiff instance + * @param layerInput Input to the layer + * @param mask Optional, maybe null. Mask to apply if supported + * @param paramTable Parameter table - keys and shapes as defined in the layer implementation class. + */ + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + for(SDVariable variable : paramTable.values()){ + variable.convertToConstant(); + } + return defineUnderlying(sameDiff, layerInput, mask, paramTable); + } + @Override public ParamInitializer initializer() { return FrozenLayerWithBackpropParamInitializer.getInstance(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java index e7f252c4ba50..1a6b2ed4c5bb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java @@ -26,6 +26,8 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -79,6 +81,24 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, return ret; } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + layerInput = sameDiff.expandDims(layerInput, -1); // [batch, size, 1] + SDVariable out; + out = sameDiff.tile(layerInput, 1, 1, n); // [batch, size, n] + + //noinspection StatementWithEmptyBody + if(dataFormat == RNNFormat.NCW){ + } else if(dataFormat == RNNFormat.NWC) { + out = out.permute(0, 2, 1); + } else { + throw new UnsupportedOperationException("Unknown RNN data format " + dataFormat); + } + + return doActivation(out); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.FF) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java index 792e5633b36c..24a79ca759f7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.conf.layers.recurrent; +import java.util.HashMap; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; @@ -32,6 +33,8 @@ import org.deeplearning4j.nn.params.BidirectionalParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.TimeSeriesUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; @@ -43,7 +46,6 @@ import java.util.Map; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -import static org.nd4j.linalg.indexing.NDArrayIndex.point; /** * Bidirectional is a "wrapper" layer: it wraps any uni-directional RNN layer to make it bidirectional.
Note that @@ -80,6 +82,7 @@ public enum Mode { private transient BidirectionalParamInitializer initializer; private Bidirectional(Bidirectional.Builder builder) { + //TODO builder params aren't used? super(builder); } @@ -110,6 +113,92 @@ public Bidirectional(@NonNull Mode mode, @NonNull Layer layer) { this.mode = mode; } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + Map fwdParams = new HashMap<>(); + Map bwdParams = new HashMap<>(); + + for(String key : params.keySet()){ + if(key.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){ + fwdParams.put(key.replaceFirst(BidirectionalParamInitializer.FORWARD_PREFIX, ""), params.get(key)); + } else if(key.startsWith(BidirectionalParamInitializer.BACKWARD_PREFIX)){ + fwdParams.put(key.replaceFirst(BidirectionalParamInitializer.BACKWARD_PREFIX, ""), params.get(key)); + } + } + + fwd.transformParamsForSameDiff(fwdParams); + bwd.transformParamsForSameDiff(bwdParams); + + params.clear(); + for(Map.Entry entry : fwdParams.entrySet()){ + params.put(BidirectionalParamInitializer.FORWARD_PREFIX + entry.getKey(), entry.getValue()); + } + for(Map.Entry entry : bwdParams.entrySet()){ + params.put(BidirectionalParamInitializer.BACKWARD_PREFIX + entry.getKey(), entry.getValue()); + } + } + + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + + Map fwdParams = new HashMap<>(); + Map bwdParams = new HashMap<>(); + + for(String key : paramTable.keySet()){ + if(key.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){ + fwdParams.put(key.replaceFirst(BidirectionalParamInitializer.FORWARD_PREFIX, ""), paramTable.get(key)); + } else if(key.startsWith(BidirectionalParamInitializer.BACKWARD_PREFIX)){ + fwdParams.put(key.replaceFirst(BidirectionalParamInitializer.BACKWARD_PREFIX, ""), paramTable.get(key)); + } + } + + if(fwd instanceof BaseRecurrentLayer){ + + try{ + return ((BaseRecurrentLayer) fwd).defineBidirectional(sameDiff, layerInput, paramTable, mask, mode); + } catch (UnsupportedOperationException e) { + + + SDVariable fwdOut = ((BaseRecurrentLayer) fwd) + .defineLayer(sameDiff, layerInput, fwdParams, mask, false); + SDVariable bwdOut = ((BaseRecurrentLayer) fwd) + .defineLayer(sameDiff, layerInput, bwdParams, mask, true); + + + bwdOut = sameDiff.reverse(bwdOut, ((BaseRecurrentLayer) fwd).getRnnDataFormat() == RNNFormat.NCW ? 2 : 1); + + if(mode == Mode.CONCAT) { + if(((BaseRecurrentLayer) fwd).getRnnDataFormat() == RNNFormat.NCW) + return sameDiff.concat(1, fwdOut, bwdOut); + else + return sameDiff.concat(2, fwdOut, bwdOut); + } else if(mode == Mode.ADD) + return fwdOut.add(bwdOut); + else if(mode == Mode.AVERAGE) + return fwdOut.add(bwdOut).div(2); + else if(mode == Mode.MUL) + return fwdOut.mul(bwdOut); + else + throw new UnsupportedOperationException("Unknown bidirectional mode " + mode); + } + } else if(fwd instanceof LastTimeStep){ + SDVariable fwdOut = fwd.defineLayer(sameDiff, layerInput, mask, fwdParams); + SDVariable bwdOut = bwd.defineLayer(sameDiff, layerInput, mask, bwdParams); + if(mode == Mode.CONCAT) { + return sameDiff.concat(1, fwdOut, bwdOut); + } else if(mode == Mode.ADD) + return fwdOut.add(bwdOut); + else if(mode == Mode.AVERAGE) + return fwdOut.add(bwdOut).div(2); + else if(mode == Mode.MUL) + return fwdOut.mul(bwdOut); + else + throw new UnsupportedOperationException("Unknown bidirectional mode " + mode); + } else + throw new UnsupportedOperationException("Bidirectional toSameDiff doesn't support layer " + fwd.getClass().getSimpleName()); + } + public long getNOut() { if (this.fwd instanceof LastTimeStep) { return ((FeedForwardLayer) ((LastTimeStep) this.fwd).getUnderlying()).getNOut(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java index 52c048472f27..5ae8f6844b11 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java @@ -16,12 +16,17 @@ package org.deeplearning4j.nn.conf.layers.recurrent; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -60,6 +65,13 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, initializeParams, networkDataType)); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable underlyingOutput = defineUnderlying(sameDiff, layerInput, mask, paramTable); + return underlyingOutput.get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java index 7bc91c17eb00..4bce1dd45bd0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java @@ -19,21 +19,27 @@ import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.LayerValidation; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; /** * Simple RNN - aka "vanilla" RNN is the simplest type of recurrent neural network layer. It implements {@code out_t = diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java index 5489ccc78d0e..8887e7ddbff8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -1,5 +1,6 @@ package org.deeplearning4j.nn.conf.layers.recurrent; +import java.util.Map; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NonNull; @@ -11,8 +12,12 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.shade.jackson.annotation.JsonProperty; import java.util.Collection; @@ -53,6 +58,36 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, initializeParams, networkDataType), rnnDataFormat); } + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + SDVariable originalShape = layerInput.shape(); + SDVariable batch = originalShape.get(SDIndex.point(0)); + SDVariable sequenceLength; + + SDVariable neg1 = sameDiff.constant(Nd4j.scalar(batch.dataType(), -1)); + + if(rnnDataFormat == RNNFormat.NCW) { + sequenceLength = originalShape.get(SDIndex.point(2)); + layerInput = layerInput.permute(0, 2, 1); + } else if(rnnDataFormat == RNNFormat.NWC) + sequenceLength = originalShape.get(SDIndex.point(1)); + else + throw new UnsupportedOperationException("Unknown RNN data format " + rnnDataFormat); + + SDVariable distributedShape = sameDiff.concat(0, batch.mul(sequenceLength), neg1); + SDVariable distributedInput = layerInput.reshape(distributedShape); + + SDVariable distributedOutput = defineUnderlying(sameDiff, distributedInput, mask, paramTable); + + SDVariable temp = distributedOutput.reshape(sameDiff.concat(0, batch, sequenceLength, neg1)); + + if(rnnDataFormat == RNNFormat.NCW) + return temp.permute(0, 2, 1); + else + return temp; + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType.getType() != InputType.Type.RNN) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java index 7271d59efaad..d4800dc174b8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java @@ -42,7 +42,8 @@ public abstract class SameDiffLambdaLayer extends SameDiffLayer { public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput); @Override - public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable) { return defineLayer(sameDiff, layerInput); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java index f290a09f36f1..5452471363f0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java @@ -74,12 +74,12 @@ protected SameDiffLayer() { * * @param sameDiff SameDiff instance * @param layerInput Input to the layer - * @param paramTable Parameter table - keys as defined by {@link #defineParameters(SDLayerParams)} * @param mask Optional, maybe null. Mask to apply if supported + * @param paramTable Parameter table - keys as defined by {@link #defineParameters(SDLayerParams)} * @return The final layer variable corresponding to the activations/output from the forward pass */ - public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, - Map paramTable, SDVariable mask); + public abstract SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable); /** * @see Layer#feedForwardMaskArray(INDArray, MaskState, int) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java index d6c4892f37fd..6e779485c081 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.conf.layers.samediff; +import lombok.NonNull; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.autodiff.samediff.SDVariable; @@ -63,9 +64,15 @@ protected SameDiffOutputLayer() { * @param paramTable Parameter table - keys as defined by {@link #defineParameters(SDLayerParams)} * @return The final layer variable corresponding to the score/loss during forward pass. This must be a single scalar value. */ - public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, + public abstract SDVariable defineLayerAndLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map paramTable); + @Override + public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, + SDVariable mask, @NonNull Map paramTable) { + throw new IllegalStateException("SameDiffOutputLayers should be defined using the define method using labels"); + } + /** * Output layers should terminate in a single scalar value (i.e., a score) - however, sometimes the output activations * (such as softmax probabilities) need to be returned. When this is the case, we need to know the name of the diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java index fd9c64ba6643..e0d89a98eb00 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java @@ -16,9 +16,11 @@ package org.deeplearning4j.nn.conf.layers.util; +import java.util.Map; import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -27,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java index c11f618dafee..ee987f2e6ef3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java @@ -16,7 +16,9 @@ package org.deeplearning4j.nn.conf.layers.wrapper; +import java.util.Map; import lombok.Data; +import lombok.NonNull; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; @@ -24,6 +26,10 @@ import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.WrapperLayerParamInitializer; +import org.nd4j.autodiff.samediff.NameScope; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; @@ -54,6 +60,19 @@ public ParamInitializer initializer() { return WrapperLayerParamInitializer.getInstance(); } + @Override + public void transformParamsForSameDiff(@NonNull Map params) { + underlying.transformParamsForSameDiff(params); + } + + protected SDVariable defineUnderlying(SameDiff sameDiff, SDVariable layerInput, SDVariable mask, + Map paramTable){ + NameScope underlyingScope = sameDiff.withNameScope("underlying"); + SDVariable output = underlying.defineLayer(sameDiff, layerInput, mask, paramTable); + underlyingScope.close(); + return output; + } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { return underlying.getOutputType(layerIndex, inputType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java index 539289ecafbd..2fd04496064c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java @@ -24,6 +24,8 @@ import org.deeplearning4j.nn.conf.layers.LayerValidation; import org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; @@ -124,6 +126,25 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection paramTable) { + SDVariable w = paramTable.get(OCNNParamInitializer.W_KEY); + SDVariable v = paramTable.get(OCNNParamInitializer.V_KEY); + + SDVariable wFlat = w.reshape(sameDiff.concat(0, sameDiff.sizeAt(w, 0), sameDiff.constant(-1))); + + SDVariable first = layerInput.mul(v); + SDVariable act2d = doActivation(first); + return act2d.mul(wFlat); //TODO DL4J implementation sets labels to the output as well, will this work here? probably not + } + + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, + boolean average) { + return lossFn.defineLoss(sameDiff, input, input, average); + } + @Override public long getNOut() { //we don't change number of outputs here diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java index 83f1097b7820..5203cca260bd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java @@ -16,8 +16,11 @@ package org.deeplearning4j.nn.conf.preprocessor; +import lombok.NonNull; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -43,4 +46,8 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt //Default: pass-through, unmodified return new Pair<>(maskArray, currentMaskState); } + + public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input){ + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java index f6ba7af7b344..0041fadfbc07 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java @@ -49,7 +49,7 @@ * @see FeedForwardToCnn3DPreProcessor for opposite case (i.e., DenseLayer -> CNN3D) */ @Data -public class Cnn3DToFeedForwardPreProcessor implements InputPreProcessor { +public class Cnn3DToFeedForwardPreProcessor extends BaseInputPreProcessor { protected long inputDepth; protected long inputHeight; protected long inputWidth; @@ -141,16 +141,6 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr } - @Override - public Cnn3DToFeedForwardPreProcessor clone() { - try { - Cnn3DToFeedForwardPreProcessor clone = (Cnn3DToFeedForwardPreProcessor) super.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } - } - @Override public InputType getOutputType(InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java index 1a7e3928b9be..0779d4c1786e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java @@ -17,11 +17,14 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.Data; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.common.primitives.Pair; @@ -51,7 +54,7 @@ * @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc) */ @Data -public class CnnToFeedForwardPreProcessor implements InputPreProcessor { +public class CnnToFeedForwardPreProcessor extends BaseInputPreProcessor { protected long inputHeight; protected long inputWidth; protected long numChannels; @@ -157,16 +160,6 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace } - @Override - public CnnToFeedForwardPreProcessor clone() { - try { - CnnToFeedForwardPreProcessor clone = (CnnToFeedForwardPreProcessor) super.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } - } - @Override public InputType getOutputType(InputType inputType) { if (inputType == null || inputType.getType() != InputType.Type.CNN) { @@ -195,4 +188,9 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt return new Pair<>(maskArray.reshape(maskArray.ordering(), maskArray.size(0), maskArray.size(1)), currentMaskState); } + + @Override + public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return input.reshape(-1, numChannels * inputHeight * inputWidth); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java index 6f18e70e4e88..de4b5c5a2338 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java @@ -48,7 +48,7 @@ */ @Data @EqualsAndHashCode(exclude = {"product"}) -public class CnnToRnnPreProcessor implements InputPreProcessor { +public class CnnToRnnPreProcessor extends BaseInputPreProcessor { private long inputHeight; private long inputWidth; private long numChannels; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java index 55f0e6b12e95..bd2fa5bdbbca 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java @@ -18,10 +18,13 @@ import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NonNull; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.workspace.ArrayType; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -90,4 +93,11 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt } return new Pair<>(maskArray, currentMaskState); } + + @Override + public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + for(InputPreProcessor preProcessor : inputPreProcessors) + input = preProcessor.definePreProcess(sameDiff, input); + return input; + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java index 305ace53012d..461e8de77336 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java @@ -48,7 +48,7 @@ */ @Data @EqualsAndHashCode(exclude = {"shape"}) -public class FeedForwardToCnn3DPreProcessor implements InputPreProcessor { +public class FeedForwardToCnn3DPreProcessor extends BaseInputPreProcessor { private int inputDepth; private int inputHeight; private int inputWidth; @@ -126,14 +126,10 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr @Override public FeedForwardToCnn3DPreProcessor clone() { - try { - FeedForwardToCnn3DPreProcessor clone = (FeedForwardToCnn3DPreProcessor) super.clone(); - if (clone.shape != null) - clone.shape = clone.shape.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } + FeedForwardToCnn3DPreProcessor clone = (FeedForwardToCnn3DPreProcessor) super.clone(); + if (clone.shape != null) + clone.shape = clone.shape.clone(); + return clone; } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java index 817d538489e3..c7ef7e079493 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java @@ -20,6 +20,8 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.common.primitives.Pair; @@ -48,7 +50,7 @@ */ @Data @EqualsAndHashCode(exclude = {"shape"}) -public class FeedForwardToCnnPreProcessor implements InputPreProcessor { +public class FeedForwardToCnnPreProcessor extends BaseInputPreProcessor { private long inputHeight; private long inputWidth; private long numChannels; @@ -115,14 +117,10 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr @Override public FeedForwardToCnnPreProcessor clone() { - try { - FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone(); - if (clone.shape != null) - clone.shape = clone.shape.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } + FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone(); + if (clone.shape != null) + clone.shape = clone.shape.clone(); + return clone; } @Override @@ -167,4 +165,13 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt return new Pair<>(maskArray, currentMaskState); } + @Override + public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + //TODO Assuming shape of input is correct, it would be better to check & throw exception here, but needs offline shape inference + + if(numChannels == -1) + throw new IllegalStateException("Can't convert when numChannels isn't explicitly specified"); + //TODO get batch size. Needs offline shape inference + return input.reshape(-1, numChannels, inputHeight, inputWidth); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java index e6bca1bed18a..34b5417248f2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java @@ -48,7 +48,7 @@ */ @Data @NoArgsConstructor -public class FeedForwardToRnnPreProcessor implements InputPreProcessor { +public class FeedForwardToRnnPreProcessor extends BaseInputPreProcessor { private RNNFormat rnnDataFormat = RNNFormat.NCW; public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java index 57487aae7267..27e077400116 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java @@ -48,7 +48,7 @@ */ @Data @EqualsAndHashCode(exclude = {"product"}) -public class RnnToCnnPreProcessor implements InputPreProcessor { +public class RnnToCnnPreProcessor extends BaseInputPreProcessor { private int inputHeight; private int inputWidth; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java index d4355e38b651..507e436b80aa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java @@ -51,7 +51,7 @@ @Data @Slf4j @NoArgsConstructor -public class RnnToFeedForwardPreProcessor implements InputPreProcessor { +public class RnnToFeedForwardPreProcessor extends BaseInputPreProcessor { private RNNFormat rnnDataFormat = RNNFormat.NCW; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 2f7bd45eeecf..990b86f87ab1 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -24,7 +24,16 @@ import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.LayerWithLoss; +import org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer; +import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater; +import org.deeplearning4j.util.ToSameDiffUtils; import org.nd4j.adapters.OutputAdapter; +import org.nd4j.autodiff.samediff.NameScope; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.TrainingConfig; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.api.*; @@ -91,6 +100,9 @@ import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace; import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Triple; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.NoOp; +import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.workspace.ND4JWorkspaceException; import org.nd4j.linalg.workspace.WorkspaceUtils; @@ -747,6 +759,208 @@ public void init(INDArray parameters, boolean cloneParametersArray) { initCalled = true; } + /** + * + * Create the MultiLayerNetwork in a SameDiff instance. + * + * The input and lables placeholders are created with names "input" and "labels", respectively. + * Output and loss variables are set on the SameDiff instance and can be gotten from it. + * + * @param sameDiff The SameDiff instance to create the model in + * @param inputTypes The types of the inputs. + * @param useView whether to directly use the (view) weights in the SDVariables, or create new ones. + * Using them saves an initialization (of every weight), but may cause issues with multi-gpu setups. + * @param skipErrors Whether to ignore updater or regularization configuration if they aren't the same on all layers. + * @return The {@link org.nd4j.autodiff.samediff.TrainingConfig} if training is setup (the last layer is an BaseOutputLayer), or null if not. + */ + public org.nd4j.autodiff.samediff.TrainingConfig toSameDiff(@NonNull SameDiff sameDiff, @NonNull Map inputTypes, boolean useView, boolean skipErrors) { + + if (!initCalled) + init(); + + Preconditions.checkArgument(inputTypes.keySet().equals(new HashSet<>(configuration.getNetworkInputs())), "Must specify input types for all inputs. Expected %s, but got %s.", + inputTypes.keySet(), configuration.getNetworkInputs()); + + InputType[] inputVertTypes = new InputType[inputTypes.size()]; + int j = 0; + for(String inputName : configuration.getNetworkInputs()){ + inputVertTypes[j] = inputTypes.get(inputName); + j++; + } + + Map outputTypes = configuration.getLayerActivationTypes(true, inputVertTypes); + + Map activations = new HashMap<>(); + for(Map.Entry input : inputTypes.entrySet()){ + activations.put(input.getKey(), sameDiff.placeHolder(input.getKey(), configuration.getDataType(), input.getValue().getShape(true))); + } + + Map sdOutputLabels = new HashMap<>(); + + for (int i : topologicalOrder) { + GraphVertex vertex = vertices[i]; + String name = vertex.getVertexName(); + + if(vertex instanceof InputVertex) + continue; + + NameScope layerScope = sameDiff.withNameScope(name); + + Map paramTable = ToSameDiffUtils.defineParams(sameDiff, vertex, useView); + + SDVariable[] inputs = new SDVariable[vertex.getNumInputArrays()]; + j = 0; + for(String inputVertex : configuration.getVertexInputs().get(name)){ + inputs[j] = activations.get(inputVertex); + j++; + } + + SDVariable output; + if(vertex.hasLayer() && vertex.getLayer() instanceof SameDiffOutputLayer){ + String inputName = configuration.getVertexInputs().get(name).get(0); + SDVariable labels = null; + if(((SameDiffOutputLayer) vertex.getLayer()).needsLabels()){ + labels = sameDiff + .placeHolder("labels", configuration.getDataType(), outputTypes.get(inputName).getShape(true)); + } + + SDVariable input = activations.get(inputName); + output = ((SameDiffOutputLayer) vertex.getLayer()).layerConf().defineLayerAndLoss(sameDiff, input, labels, paramTable); + sdOutputLabels.put(name, labels); + } else { + output = vertex.defineVertex(sameDiff, inputs, null, paramTable); + } + + activations.put(name, output); + + layerScope.close(); + } + + List sdOutputs = new ArrayList<>(); + for(String vertex : configuration.getNetworkOutputs()){ + sdOutputs.add(activations.get(vertex).name()); + } + + sameDiff.setOutputs(sdOutputs); + + List losses = new ArrayList<>(); + List allLabels = new ArrayList<>(); + + for(String output : configuration.getNetworkOutputs()){ + GraphVertex vertex = verticesMap.get(output); + SDVariable loss; + SDVariable labels; + if(vertex.hasLayer() && vertex.getLayer() instanceof SameDiffOutputLayer) { + loss = activations.get(vertex.getVertexName()); + labels = sdOutputLabels.get(vertex.getVertexName()); + + } else if(vertex.hasLayer() && vertex.getLayer() instanceof IOutputLayer && vertex.getLayer().conf().getLayer() instanceof LayerWithLoss){ + LayerWithLoss lossLayer = (LayerWithLoss) vertex.getLayer().conf().getLayer(); + SDVariable input = activations.get(output); + labels = null; + + NameScope vertexScope = sameDiff.withNameScope(vertex.getVertexName()); + + if(((IOutputLayer) vertex.getLayer()).needsLabels()) { + labels = sameDiff + .placeHolder("labels", configuration.getDataType(), outputTypes.get(output).getShape(true)); + } + NameScope lossScope = sameDiff.withNameScope("loss"); + + loss = lossLayer.defineLoss(sameDiff, input, labels, conf().isMiniBatch()); + lossScope.close(); + + loss.rename("loss"); + + vertexScope.close(); + + } else { + continue; + } + + losses.add(loss.name()); + if(labels != null) + allLabels.add(labels.name()); + } + + if(losses.size() > 0){ + + IUpdater iUpdater = ToSameDiffUtils.getUpdater(layers, skipErrors); + List regularizations = ToSameDiffUtils.getRegularizations(layers, skipErrors); + + String[] lossArr = losses.toArray(new String[0]); + sameDiff.setLossVariables(lossArr); + + TrainingConfig.Builder tcBuilder = org.nd4j.autodiff.samediff.TrainingConfig.builder() + .minimize(lossArr) + .minimize(conf().isMinimize()) + .dataSetFeatureMapping(configuration.getNetworkInputs().toArray(new String[0])); + + if(regularizations != null) + tcBuilder.regularization(regularizations); + + if(iUpdater != null) + tcBuilder.updater(iUpdater.clone()); + else + tcBuilder.updater(new NoOp()); + + if(allLabels.size() == 0) + tcBuilder.markLabelsUnused(); + else + tcBuilder.dataSetLabelMapping(allLabels.toArray(new String[0])); + + org.nd4j.autodiff.samediff.TrainingConfig trainingConfig = tcBuilder.build(); + + trainingConfig.setIterationCount(getIterationCount()); + trainingConfig.setEpochCount(getEpochCount()); + + sameDiff.setTrainingConfig(trainingConfig); + + if(iUpdater != null) { + Updater updater = getUpdater(); + + if(updater instanceof BaseMultiLayerUpdater){ + ToSameDiffUtils.copyUpdaterState(sameDiff, (BaseMultiLayerUpdater) updater, null); + } else { + if(skipErrors) + log.warn("Unsupported updater type {}, not copying updater state to SameDiff", updater.getClass().getSimpleName()); + else + throw new IllegalStateException("Unsupported updater type " + updater.getClass().getSimpleName() + ", could not updater state to SameDiff"); + } + + + } + + return trainingConfig; + } + + return null; + } + + + /** + * See {@link #toSameDiff(SameDiff, Map, boolean, boolean)}. {@code useView} and {@code skipErrors} are true. + */ + public TrainingConfig toSameDiff(@NonNull SameDiff sameDiff, @NonNull Map inputTypes){ + return toSameDiff(sameDiff, inputTypes, true, true); + } + + /** + * See {@link #toSameDiff(SameDiff, Map, boolean, boolean)}. + */ + public SameDiff toSameDiff(@NonNull Map inputTypes, boolean useView, boolean skipErrors){ + SameDiff sameDiff = SameDiff.create(); + toSameDiff(sameDiff, inputTypes, useView, skipErrors); + return sameDiff; + } + + /** + * See {@link #toSameDiff(SameDiff, Map, boolean, boolean)}. {@code useView} and {@code skipErrors} are true. + */ + public SameDiff toSameDiff(@NonNull Map inputTypes){ + return toSameDiff(inputTypes, true, true); + } + /** * This method: initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers. * As a general rule, this shouldn't ever need to be called manually when doing training via fit(DataSet), fit(DataSetIterator) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java index 555c64f94d03..c2b13daab697 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java @@ -18,10 +18,13 @@ import lombok.Data; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -76,6 +79,18 @@ protected BaseGraphVertex(ComputationGraph graph, String name, int vertexIndex, this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)]; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } + + @Override + public void transformParamsForSameDiff(@NonNull Map params){ + if(hasLayer()) + getLayer().conf().getLayer().transformParamsForSameDiff(params); + } + @Override public String getVertexName() { return vertexName; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java index de6b18428335..3e8128f3cc8b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java @@ -16,11 +16,14 @@ package org.deeplearning4j.nn.graph.vertex; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -41,6 +44,17 @@ protected BaseWrapperVertex(GraphVertex underlying){ this.underlying = underlying; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } + + @Override + public void transformParamsForSameDiff(@NonNull Map params){ + underlying.transformParamsForSameDiff(params); + } + @Override public String getVertexName() { return underlying.getVertexName(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java index a2477f9407c2..0a5610950658 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java @@ -16,10 +16,13 @@ package org.deeplearning4j.nn.graph.vertex; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.Trainable; import org.deeplearning4j.nn.gradient.Gradient; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -92,6 +95,33 @@ public interface GraphVertex extends Trainable, Serializable { /** Get the Layer (if any). Returns null if {@link #hasLayer()} == false */ Layer getLayer(); + /** + * Define the vertex for conversion to {@link SameDiff}.
+ * If this isn't supported, this method should throw a {@link UnsupportedOperationException} + * like the default implementation in {@link BaseGraphVertex}. + * + * @param sameDiff The {@link SameDiff} instance to define in. + * @param inputs The inputs to the vertex, in the same order as {@link #getInputVertices()}. + * @param mask The mask. May be null. + * @param paramTable The parameters for the vertex. Keys will be the same as {@link #paramTable(boolean)}. + * @return The output of the vertex. + */ + SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, SDVariable mask, + @NonNull Map paramTable); + + /** + * Do any necessary transforms to parameters (weights, biases, etc) before making SDVariables out of them. + * Useful for things like changing the dimension order or squeezing. + * + * Adding or removing parameters is supported.
+ * + * Should throw a {@link UnsupportedOperationException} if conversion of this layer configuration isn't + * supported and it will cause an error when transforming weights. + * + * @param params The parameter. + */ + void transformParamsForSameDiff(@NonNull Map params); + /** Set the input activations. * @param inputNumber Must be in range 0 to {@link #getNumInputArrays()}-1 * @param input The input array diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java index 5018dbe71829..1a18b07eae16 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -76,6 +80,36 @@ public Layer getLayer() { return null; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + if(inputs.length == 1) + return inputs[0]; + + if (op == Op.Subtract && inputs.length != 2) + throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs"); + + SDVariable acc = inputs[0]; + for(int i = 1 ; i < inputs.length ; i++){ + SDVariable next = inputs[i]; + if(op == Op.Add) + acc = acc.add(next); + else if(op == Op.Subtract) + acc = acc.sub(next); + else if(op == Op.Product) + acc = acc.mul(next); + else if(op == Op.Average) + acc = acc.add(next); + else if(op == Op.Max) + acc = sameDiff.math.max(acc, next); + } + + if(op == Op.Average) + acc = acc.div(inputs.length); + + return acc; + } + @Override public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { if (!canDoForward()) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java index 955ba8aba274..07e6648bf6a9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java @@ -16,16 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; -import lombok.AllArgsConstructor; +import java.util.Map; import lombok.EqualsAndHashCode; +import lombok.NonNull; import org.deeplearning4j.nn.api.TrainingConfig; -import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.misc.DummyConfig; import org.deeplearning4j.nn.graph.vertex.BaseWrapperVertex; import org.deeplearning4j.nn.graph.vertex.GraphVertex; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.learning.config.NoOp; +import org.nd4j.autodiff.samediff.NameScope; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; /** * FrozenVertex is used for the purposes of transfer learning @@ -48,4 +48,16 @@ public TrainingConfig getConfig(){ } return config; } + + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + for(SDVariable variable : paramTable.values()){ + variable.convertToConstant(); + } + NameScope underlyingScope = sameDiff.withNameScope("underlying"); + SDVariable output = underlying.defineVertex(sameDiff, inputs, mask, paramTable); + underlyingScope.close(); + return output; + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java index 32e67134514a..de92a3249357 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -37,6 +41,12 @@ public InputVertex(ComputationGraph graph, String name, int vertexIndex, VertexI super(graph, name, vertexIndex, null, outputVertices, dataType); } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + throw new IllegalStateException("InputVertices should never be manually converted to SameDiff"); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java index 894c026eda24..37c29ecf41eb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -59,6 +63,17 @@ public L2NormalizeVertex(ComputationGraph graph, String name, int vertexIndex, V this.eps = eps; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + + if(dimension == null || dimension.length < 1) + throw new IllegalStateException("Dimension must be set for toSameDiff conversion."); + + SDVariable factor = sameDiff.max(inputs[0].norm2(dimension), sameDiff.constant(Nd4j.scalar(inputs[0].dataType(), eps))); + return inputs[0].div(factor); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java index 20394880caab..7755166fc3cc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -54,6 +58,16 @@ public L2Vertex(ComputationGraph graph, String name, int vertexIndex, VertexIndi this.eps = eps; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + SDVariable temp = inputs[0].sub(inputs[1]); + temp = temp.mul(temp); + temp = temp.reshape(sameDiff.concat(0, sameDiff.sizeAt(temp, 0), sameDiff.constant(Nd4j.scalar(DataType.INT64, -1)))) + .sum(true, 1); + return sameDiff.math.sqrt(temp); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java index d803808dff42..357dc6259de0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java @@ -18,6 +18,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.TrainingConfig; @@ -31,6 +32,9 @@ import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; +import org.nd4j.autodiff.samediff.NameScope; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -76,6 +80,28 @@ public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexI this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)]; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + org.deeplearning4j.nn.conf.layers.Layer layerConf = layer.conf().getLayer(); + + InputPreProcessor preProcessor = getLayerPreProcessor(); + + SDVariable input = inputs[0]; + + if(preProcessor != null){ + NameScope preProcessorScope = sameDiff.withNameScope("inputPreprocessor"); + input = preProcessor.definePreProcess(sameDiff, input); + preProcessorScope.close(); + } + + if(layerConf.getIDropout() != null){ + input = layerConf.getIDropout().defineDropout(sameDiff, input); + } + + return layerConf.defineLayer(sameDiff, input, null, paramTable); + } + @Override public boolean hasLayer() { return true; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java index 702767e8a573..787f2498b951 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java @@ -16,6 +16,8 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; @@ -23,6 +25,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -61,6 +65,15 @@ public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexI this.mergeAxis = mergeAxis; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + if(inputs.length == 1) + return inputs[0]; + + return sameDiff.concat(mergeAxis, inputs); + } + @Override public String toString() { return "MergeVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\")"; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java index 962b020817cb..fec76a507640 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java @@ -16,12 +16,17 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; @@ -49,6 +54,16 @@ public PoolHelperVertex(ComputationGraph graph, String name, int vertexIndex, Ve super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + + if (inputs.length > 1) + throw new IllegalStateException("PoolHelper vertex requires a single input."); + + return inputs[0].get(SDIndex.all(), SDIndex.all(), SDIndex.interval(1, -1), SDIndex.interval(1, -1)); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java index d8fe856174e1..c493c84034d7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java @@ -16,6 +16,8 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; @@ -23,6 +25,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -46,6 +50,12 @@ public PreprocessorVertex(ComputationGraph graph, String name, int vertexIndex, this.preProcessor = preProcessor; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + return preProcessor.definePreProcess(sameDiff, inputs[0]); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java index 39fcac462cdb..f69b64ba7ed9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java @@ -16,6 +16,8 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; @@ -24,6 +26,8 @@ import org.deeplearning4j.nn.graph.vertex.VertexIndices; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -54,6 +58,16 @@ public ReshapeVertex(ComputationGraph graph, String name, int vertexIndex, Verte this.maskShape = maskShape; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + + if (inputs.length > 1) + throw new IllegalStateException("Reshape vertex requires a single input."); + + return inputs[0].reshape(newShape); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java index c4fa89239b68..924e41901393 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -50,6 +54,17 @@ public ScaleVertex(ComputationGraph graph, String name, int vertexIndex, VertexI this.scaleFactor = scaleFactor; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + + if (inputs.length > 1) + throw new IllegalArgumentException( + "ScaleVertex (name " + vertexName + " idx " + vertexIndex + ") only supports 1 input."); + + return inputs[0].mul(scaleFactor); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java index d9f5c78de6f9..74589f3a65b6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -59,6 +63,17 @@ public ShiftVertex(ComputationGraph graph, String name, int vertexIndex, VertexI this.shiftFactor = shiftFactor; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + + if (inputs.length > 1) + throw new IllegalArgumentException( + "ShiftVertex (name " + vertexName + " idx " + vertexIndex + ") only supports 1 input."); + + return inputs[0].add(shiftFactor); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java index 3be9d6895581..b330fef451b8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java @@ -16,6 +16,8 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; @@ -23,6 +25,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java index db44492935f8..e8af7a29caf7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java index 9eb4151c8193..b60e5175676d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java @@ -16,14 +16,20 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; +import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.common.primitives.Pair; @@ -59,6 +65,13 @@ public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, Verte this.stackSize = stackSize; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + //TODO no way to calculate step as an int or get with a SDVariable + return super.defineVertex(sameDiff, inputs, mask, paramTable); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java index 8b3f2fba0b12..126cd8b64871 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java @@ -16,6 +16,8 @@ package org.deeplearning4j.nn.graph.vertex.impl.rnn; +import java.util.Map; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; @@ -23,6 +25,8 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java index 75ce3be3b491..9dfbf302f27d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java @@ -16,6 +16,8 @@ package org.deeplearning4j.nn.graph.vertex.impl.rnn; +import java.util.Map; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; @@ -23,6 +25,9 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; import org.deeplearning4j.nn.graph.vertex.VertexIndices; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -65,6 +70,12 @@ public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex, + "of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")"); } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + return inputs[0].get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1)); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java index 0d75119de0d4..8ece4f90b87a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java @@ -16,12 +16,16 @@ package org.deeplearning4j.nn.graph.vertex.impl.rnn; +import java.util.Map; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -63,6 +67,12 @@ public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIn } } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + return sameDiff.reverse(inputs[0], 2); + } + @Override public boolean hasLayer() { return false; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java index c1c92d3a7322..016ec73a4687 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java @@ -75,7 +75,12 @@ public INDArray activate(boolean training, LayerWorkspaceMgr mgr) { //dup required: need to keep original input for backprop in = mgr.dup(ArrayType.ACTIVATIONS, input, input.ordering()); } else { - in = mgr.leverageTo(ArrayType.ACTIVATIONS, input); + if(mgr.isScopedOut(ArrayType.ACTIVATIONS) && !input.isAttached()) { + //Edge case: input and output are both not in workspaces - dup to avoid inplace modification + in = mgr.dup(ArrayType.ACTIVATIONS, input); + } else { + in = mgr.leverageTo(ArrayType.ACTIVATIONS, input); + } } return layerConf().getActivationFn().getActivation(in, training); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java index a10bb33f3333..fcbe633216a0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java @@ -354,4 +354,8 @@ public boolean isPretrainLayer() { public boolean hasBias() { return layerConf().hasBias(); } + + public ILossFunction getLossFn() { + return layerConf().getLossFn(); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java index 7180ff446d7b..b9ad155a1a0c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java @@ -337,4 +337,8 @@ protected INDArray getLabels2d() { return labels; } + public ILossFunction getLossFn() { + return layerConf().getLossFn(); + } + } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java index 212161a9e02b..1130ccdc7848 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java @@ -218,6 +218,15 @@ public boolean needsLabels() { @Override public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) { + + assertInputSet(false); + if (input.rank() != 5) + throw new UnsupportedOperationException( + "Input is not rank 5. Got input with rank " + input.rank() + " " + layerId() + " with shape " + + Arrays.toString(input.shape()) + " - expected shape [minibatch,channels,depth,height,width]"); + if (labels == null) + throw new IllegalStateException("Labels are not set (null)"); + INDArray input2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), input, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray labels2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), labels, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray maskReshaped = ConvolutionUtils.reshapeCnn3dMask(layerConf().getDataFormat(), maskArray, input, workspaceMgr, ArrayType.FF_WORKING_MEM); @@ -279,4 +288,8 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, summedScores); } + + public ILossFunction getLossFn() { + return layerConf().getLossFn(); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java index 06c3b237544c..c2338cd1affd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java @@ -198,6 +198,16 @@ public boolean needsLabels() { @Override public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(false); + if (input.rank() != 4) + throw new UnsupportedOperationException( + "Input is not rank 4. Got input with rank " + input.rank() + " " + layerId() + " with shape " + + Arrays.toString(input.shape()) + " - expected shape " + layerConf().getFormat().dimensionNames()); + if (labels == null) + throw new IllegalStateException("Labels are not set (null)"); + + Preconditions.checkState(input.equalShapes(labels), "Input and label arrays do not have same shape: %ndShape vs. %ndShape",input, labels); + INDArray input2d = ConvolutionUtils.reshape4dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray labels2d = ConvolutionUtils.reshape4dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM); INDArray maskReshaped = ConvolutionUtils.reshapeMaskIfRequired(maskArray, input, layerConf().getFormat(), workspaceMgr, ArrayType.FF_WORKING_MEM); @@ -248,4 +258,8 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, summedScores); } + + public ILossFunction getLossFn() { + return layerConf().getLossFn(); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java index 4d118c62bf0d..d160854bdcf7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java @@ -680,4 +680,8 @@ public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int cl INDArray conf = networkOutput.get(point(example), point(5*bbs + classNumber), all(), all()); return conf; } + + public ILossFunction getLossFn() { + return null; + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java index e6c14f2e4703..ec26baf7b18a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java @@ -18,6 +18,7 @@ import lombok.Getter; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -26,6 +27,9 @@ import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.api.buffer.DataType; @@ -34,6 +38,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -289,7 +294,7 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr return summedScores; } - public class OCNNLossFunction implements ILossFunction { + public class OCNNLossFunction extends BaseLossFunction { @Override public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index 28913681f67d..ced54fec3bda 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -213,6 +213,22 @@ public boolean needsLabels() { @Override public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) { + assertInputSet(false); + if (input.rank() != 3) + throw new UnsupportedOperationException( + "Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " + + input.rank() + " with shape " + Arrays.toString(input.shape()) + " for layer " + layerId()); + if (labels == null) + throw new IllegalStateException("Labels are not set (null)"); + + if (layerConf().getRnnDataFormat() == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + labels = labels.permute(0, 2, 1); + } + Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels); + Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" + + "Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels); + INDArray input = this.input; INDArray labels = this.labels; if (layerConf().getRnnDataFormat() == RNNFormat.NWC){ @@ -288,4 +304,8 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr return summedScores; } + + public ILossFunction getLossFn() { + return layerConf().getLossFn(); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index 979d2b7be23f..0bdf77cb63b2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -109,6 +109,12 @@ public Layer.Type type() { protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) { assertInputSet(false); if (input.rank() == 3) { + + RNNFormat format = layerConf().getRnnDataFormat(); + int td = (format == RNNFormat.NCW) ? 2 : 1; + Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels); + Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" + + "Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels); //Case when called from RnnOutputLayer INDArray inputTemp = input; input = (layerConf().getRnnDataFormat()==RNNFormat.NWC)? input.permute(0, 2, 1):input; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index 7fc7af03f59f..496b4ccbe1f2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.samediff; +import lombok.NonNull; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; @@ -82,6 +83,23 @@ public SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String this.params = paramsView; } + @Override + public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, + SDVariable mask, @NonNull Map paramTable) { + Map inputMap = new HashMap<>(); + + //TODO input validation? +// config.validateInput(inputs); + + for(int i=0; i()); + } + @Override public String toString() { return null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 948837166f80..250d97e60f1f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -313,7 +313,7 @@ protected void doInit(){ long[] maskShape = ArrayUtil.nTimes((long)inputShape.length, -1); SDVariable mask = sameDiff.placeHolder(MASK_KEY, dataType, maskShape); - SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask); + SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, mask, params); Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null"); outputVar = layerOutput; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index 2b28fe952eee..7323434be42c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -321,7 +321,7 @@ protected void doInit(){ SDVariable v = sameDiff.var(s, dataType, ps); params.put(s, v); } - SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params); + SDVariable layerOutput = bl.defineLayerAndLoss(sameDiff, inputVar, labelVar, params); Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null"); outputVar = layerOutput; @@ -331,6 +331,7 @@ protected void doInit(){ } this.outputKey = layerOutput.name(); + sameDiff.setLossVariables(outputKey); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 2091babb0abf..8f3148da941a 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -17,6 +17,21 @@ package org.deeplearning4j.nn.multilayer; +import java.io.File; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; import lombok.Getter; import lombok.NonNull; import lombok.Setter; @@ -28,13 +43,26 @@ import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.api.Classifier; +import org.deeplearning4j.nn.api.FwdPassType; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.api.ModelAdapter; +import org.deeplearning4j.nn.api.NeuralNetwork; +import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.Updater; -import org.deeplearning4j.nn.api.*; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.RecurrentLayer; -import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerWithLoss; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -43,7 +71,9 @@ import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; +import org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater; import org.deeplearning4j.nn.updater.UpdaterCreator; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -55,8 +85,15 @@ import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.NetworkUtils; import org.deeplearning4j.util.OutputLayerUtil; +import org.deeplearning4j.util.ToSameDiffUtils; import org.nd4j.adapters.OutputAdapter; +import org.nd4j.autodiff.samediff.NameScope; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.primitives.Pair; +import org.nd4j.common.primitives.Triple; +import org.nd4j.common.util.OneTimeLogger; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; @@ -85,16 +122,13 @@ import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils; import org.nd4j.linalg.heartbeat.utils.TaskUtils; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.common.primitives.Pair; -import org.nd4j.common.primitives.Triple; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.NoOp; +import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.util.FeatureUtil; import org.nd4j.linalg.workspace.ND4JWorkspaceException; import org.nd4j.linalg.workspace.WorkspaceUtils; -import org.nd4j.common.util.OneTimeLogger; - -import java.io.*; -import java.util.*; ; @@ -755,6 +789,191 @@ public void init(INDArray parameters, boolean cloneParametersArray) { synchronizeIterEpochCounts(); } + /** + * + * Create the MultiLayerNetwork in a SameDiff instance. + * + * The input and lables placeholders are created with names "input" and "labels", respectively. + * Output and loss variables are set on the SameDiff instance and can be gotten from it. + * + * @param sameDiff The SameDiff instance to create the model in + * @param inputType The type of the input. May be null, in which case we try to use the previously set input type, or infer it if it hasn't been set. + * @param useView whether to directly use the (view) weights in the SDVariables, or create new ones. + * Using them saves an initialization (of every weight), but may cause issues with multi-gpu setups. + * @return The {@link org.nd4j.autodiff.samediff.TrainingConfig} if training is setup (the last layer is an BaseOutputLayer), or null if not. + */ + public org.nd4j.autodiff.samediff.TrainingConfig toSameDiff(@NonNull SameDiff sameDiff, InputType inputType, boolean useView, boolean skipErrors) { + + if(inputType == null){ + Preconditions.checkState(layerWiseConfigurations.getInputType() != null, "Must specify an input type or have it inferred for SameDiff conversion"); + inputType = layerWiseConfigurations.getInputType(); + } + + if (!isInitCalled()) + init(); + + SDVariable input = sameDiff + .placeHolder("input", getLayerWiseConfigurations().getDataType(), inputType.getShape(true)); + SDVariable currentOutput = input; + + InputType currentInputType = inputType; + + SDVariable sdOutputLabels = null; + + List layerNames = ToSameDiffUtils.getScopeNames(layers); + + for (int i = 0; i < layers.length; i++) { + Layer layer = layers[i]; + + if(layer instanceof SameDiffOutputLayer){ + if(i != layers.length - 1) + throw new IllegalStateException("A SameDiffOutputLayer must be the last layer in the model"); + } + + if (!(layer.getConfig() instanceof org.deeplearning4j.nn.conf.layers.Layer)) { + throw new UnsupportedOperationException("Can't convert non-Layer layers"); + } + + org.deeplearning4j.nn.conf.layers.Layer config = layerWiseConfigurations.getConf(i).getLayer(); + + //TODO use layer name if set + NameScope layerScope = sameDiff.withNameScope(layerNames.get(i)); + + // preprocessor + InputPreProcessor preProcessor = layerWiseConfigurations.getInputPreProcess(i); + + if (preProcessor != null) { + NameScope preProcessorScope = sameDiff.withNameScope("inputPreprocessor"); + + currentOutput = preProcessor.definePreProcess(sameDiff, currentOutput); + currentInputType = preProcessor.getOutputType(currentInputType); + preProcessorScope.close(); + } + + // create weights + + Map paramTable = ToSameDiffUtils.defineParams(sameDiff, layer, useView); + + if(config.getIDropout() != null){ + currentOutput = config.getIDropout().defineDropout(sameDiff, currentOutput); + } + + // layer + + if(config instanceof org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer){ + if(((org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer) config).labelsRequired()) { + sdOutputLabels = sameDiff + .placeHolder("labels", getLayerWiseConfigurations().getDataType(), + config.getOutputType(i, currentInputType).getShape()); + } else { + sdOutputLabels = null; + } + + currentOutput = ((org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer) config).defineLayerAndLoss(sameDiff, currentOutput, sdOutputLabels, paramTable); + } else { + currentOutput = config.defineLayer(sameDiff, currentOutput, null, paramTable); + } + + currentInputType = config.getOutputType(i, currentInputType); + + layerScope.close(); + } + + sameDiff.setOutputs(currentOutput); + + org.deeplearning4j.nn.conf.layers.Layer lastLayer = getOutputLayer().conf().getLayer(); + if(lastLayer instanceof LayerWithLoss && getOutputLayer() instanceof IOutputLayer){ + // just use output + SDVariable labels; + if(((IOutputLayer) getOutputLayer()).needsLabels()) { + labels = sameDiff + .placeHolder("labels", getLayerWiseConfigurations().getDataType(), + currentInputType.getShape(true)); + } else { + labels = null; + } + + NameScope layerScope = sameDiff.withNameScope(layerNames.get(layerNames.size() - 1)); + NameScope lossScope = sameDiff.withNameScope("loss"); + + SDVariable loss = ((LayerWithLoss) lastLayer).defineLoss(sameDiff, currentOutput, labels, conf().isMiniBatch()); + lossScope.close(); + layerScope.close(); + loss.rename("loss"); + + sameDiff.setLossVariables(loss); + + + IUpdater iUpdater = ToSameDiffUtils.getUpdater(layers, skipErrors); + List regularizations = ToSameDiffUtils.getRegularizations(layers, skipErrors); + + org.nd4j.autodiff.samediff.TrainingConfig.Builder tcBuilder = org.nd4j.autodiff.samediff.TrainingConfig.builder() + .minimize(loss.name()) + .minimize(conf().isMinimize()) + .dataSetFeatureMapping(input.name()); + + if(iUpdater != null) + tcBuilder.updater(iUpdater.clone()); + else + tcBuilder.updater(new NoOp()); + + if(labels != null) + tcBuilder.dataSetLabelMapping(labels.name()); + else + tcBuilder.markLabelsUnused(); + + if(regularizations != null) + tcBuilder.regularization(regularizations); + + org.nd4j.autodiff.samediff.TrainingConfig trainingConfig = tcBuilder.build(); + + trainingConfig.setIterationCount(getIterationCount()); + trainingConfig.setEpochCount(getEpochCount()); + + sameDiff.setTrainingConfig(trainingConfig); + + if(iUpdater != null) { + Updater updater = getUpdater(); + if(updater instanceof BaseMultiLayerUpdater){ + ToSameDiffUtils.copyUpdaterState(sameDiff, (BaseMultiLayerUpdater) updater, layers); + } else { + if(skipErrors) + log.warn("Unsupported updater type {}, not copying updater state to SameDiff", updater.getClass().getSimpleName()); + else + throw new IllegalStateException("Unsupported updater type " + updater.getClass().getSimpleName() + ", could not updater state to SameDiff"); + } + } + + return trainingConfig; + } + + return null; + } + + /** + * See {@link #toSameDiff(SameDiff, InputType, boolean, boolean)}. {@code useView} and {@code skipErrors} are true. + */ + public org.nd4j.autodiff.samediff.TrainingConfig toSameDiff(@NonNull SameDiff sameDiff, InputType inputType){ + return toSameDiff(sameDiff, inputType, true, true); + } + + /** + * See {@link #toSameDiff(SameDiff, InputType, boolean, boolean)}. + * @return A new SameDiff instance with this model defined in it. The output variable is set, as is the loss variable and {@link org.nd4j.autodiff.samediff.TrainingConfig} if the last layer is an {@link org.deeplearning4j.nn.conf.layers.BaseOutputLayer}. + */ + public SameDiff toSameDiff(InputType inputType, boolean useView, boolean skipErrors){ + SameDiff sameDiff = SameDiff.create(); + toSameDiff(sameDiff, inputType, useView, skipErrors); + return sameDiff; + } + + /** + * See {@link #toSameDiff(SameDiff, InputType, boolean, boolean)}. {@code useView} and {@code skipErrors} are true. + */ + public SameDiff toSameDiff(InputType inputType){ + return toSameDiff(inputType, true, true); + } + /** * This method allows you to specificy GradientsAccumulator instance to be used with this model
*
@@ -1062,6 +1281,7 @@ protected synchronized List ffToLayerActivationsDetached(boolean train //Validation: Exception if invalid (bad layer implementation) validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, "Feed forward to layer (inference)"); + // some activations overwrite their inputs out.add(input); } if(clearInputs) { @@ -1163,6 +1383,7 @@ protected synchronized List ffToLayerActivationsInWs(int layerIndex, @ validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, "Feed forward to layer (training)"); validateArrayWorkspaces(workspaceMgr, layers[i].input(), ArrayType.INPUT, i, false, "Feed forward to layer (training)"); + // some activations overwrite their inputs out.add(input); if(traceLog){ @@ -2763,17 +2984,24 @@ public void computeGradientAndScore() { .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); //Validate activations location } - getOutputLayer().setInput(inputToOutputLayer, mgr); - //Then: compute gradients - Pair pair = calcBackpropGradients(null, true, false, false); - this.gradient = (pair == null ? null : pair.getFirst()); + IOutputLayer outputLayer = (IOutputLayer) getOutputLayer(); + outputLayer.setInput(inputToOutputLayer, mgr); + if (labels == null && outputLayer.needsLabels()) + throw new IllegalStateException("No labels found"); + outputLayer.setLabels(labels); + + //Some gradient methods overwrite their inputs, so calculate the score first //Calculate score try(MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { double r = calcRegularizationScore(true); score = ((IOutputLayer) getOutputLayer()).computeScore(r, true, mgr); } + //Then: compute gradients + Pair pair = calcBackpropGradients(null, true, false, false); + this.gradient = (pair == null ? null : pair.getFirst()); + //Listeners if (!trainingListeners.isEmpty()) { try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ToSameDiffUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ToSameDiffUtils.java new file mode 100644 index 000000000000..a31783098944 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ToSameDiffUtils.java @@ -0,0 +1,397 @@ +/* + * ****************************************************************************** + * * Copyright (c) 2020 Konduit K.K. + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.util; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.Trainable; +import org.deeplearning4j.nn.api.Updater; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.graph.vertex.GraphVertex; +import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater; +import org.deeplearning4j.nn.updater.UpdaterBlock; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.learning.GradientUpdater; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.NoOp; +import org.nd4j.linalg.learning.regularization.Regularization; + +/** + * Utilities for use in {@link org.deeplearning4j.nn.graph.ComputationGraph#toSameDiff(SameDiff, Map, boolean, boolean)} and {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork#toSameDiff(SameDiff, InputType, boolean, boolean)}. + */ +@Slf4j +public class ToSameDiffUtils { + + + /** + * Get the updater for a network. If updaters aren't the same on all layers, throws an exception or returns null depending on skipErrors. + * @param layers The layers of the network. + * @param skipErrors If true, returns null if updaters aren't the same for all layers. Otherwise, throws an error. + */ + public static IUpdater getUpdater(Layer[] layers, boolean skipErrors){ + IUpdater iUpdater = null; + for(Layer l : layers) { + org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer(); + if (conf instanceof BaseLayer) { + IUpdater u = ((BaseLayer) conf).getIUpdater(); + if (iUpdater == null) { + iUpdater = u; + } else { + if (u != null && !u.equals(iUpdater)) { + if (skipErrors) { + log.warn("Ignoring updater config: Can not convert to SameDiff with different IUpdaters. Expected {}, but was {} for {}", iUpdater, u, conf); + return null; + } else { + throw new IllegalStateException( + "Can not convert to SameDiff with different IUpdaters. Ensure all layers have the same updater. Expected " + + iUpdater + ", but was " + u + " different for " + conf); + } + } + } + + u = ((BaseLayer) conf).getBiasUpdater(); + if (iUpdater == null) { + iUpdater = u; + } else { + if (u != null && !u.equals(iUpdater)) { + if (skipErrors) { + log.warn("Ignoring updater config: Can not convert to SameDiff when layers have different IUpdaters. Expected {}, but was {} for {}", iUpdater, u, conf); + return null; + } else { + throw new IllegalStateException( + "Can not convert to SameDiff with different IUpdaters. Ensure all layers have the same updater. Expected " + + iUpdater + ", but was " + u + " for " + conf); + } + } + } + } + } + return iUpdater; + } + + /** + * Get the regularizations of a network. If regularizations aren't the same on all layers, throws an exception or returns null depending on skipErrors. + * @param layers The layers of the network. + * @param skipErrors If true, returns null if regularizations aren't the same for all layers. Otherwise, throws an error. + */ + public static List getRegularizations(Layer[] layers, boolean skipErrors){ + List regularizations = null; + + for(Layer l : layers){ + org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer(); + if(conf instanceof BaseLayer){ + if(regularizations == null){ + regularizations = ((BaseLayer) conf).getRegularization(); + } else { + if(!((BaseLayer) conf).getRegularization().equals(regularizations)) { + if(skipErrors){ + log.warn("Ignoring regularization config: Can not convert to SameDiff when layers have different regularizations. Expected {}, but was {} for {}", + regularizations, ((BaseLayer) conf).getRegularization(), conf); + return null; + } else { + throw new IllegalStateException( + "Can not convert to SameDiff with different regularizations. Ensure all layers have the same regularizations, and that bias and weight regularizations are the same. " + + "Expected " + regularizations + ", but was " + ((BaseLayer) conf) + .getRegularization() + " for " + conf); + } + } + } + + if(regularizations == null){ + regularizations = ((BaseLayer) conf).getRegularizationBias(); + } else { + if(!((BaseLayer) conf).getRegularizationBias().equals(regularizations)) { + if(skipErrors){ + log.warn("Ignoring regularization config: Can not convert to SameDiff when layers have different regularizations. Expected {}, but was {} for {}", + regularizations, ((BaseLayer) conf).getRegularization(), conf); + return null; + } else { + throw new IllegalStateException( + "Can not convert to SameDiff with different regularizations. Ensure all layers have the same regularizations, and that bias and weight regularizations are the same. " + + "Expected " + regularizations + ", but was " + ((BaseLayer) conf) + .getRegularizationBias() + " for bias in " + conf); + } + } + } + } + } + return regularizations; + } + + /** + * Define the parameters of a layer, transforming them if necessary using {@link org.deeplearning4j.nn.conf.layers.Layer#transformParamsForSameDiff(Map)}. + * + * @param sameDiff The SameDiff to define the parameters in. + * @param layer The layer whose parameters we are defining. + * @param useView Whether to use the param view directly (if true) or dup it. + * @return The SDVariable parameters of the layer. + */ + public static Map defineParams(SameDiff sameDiff, Layer layer, boolean useView){ + Map params = new HashMap<>(layer.paramTable(false)); + layer.conf().getLayer().transformParamsForSameDiff(params); + return defineTransformedParams(sameDiff, params, (int) layer.numParams(), useView); + } + + + /** + * Define the parameters of a vertex, transforming them if necessary using {@link GraphVertex#transformParamsForSameDiff(Map)}. + * + * @param sameDiff The SameDiff to define the parameters in. + * @param vertex The vertex whose parameters we are defining. + * @param useView Whether to use the param view directly (if true) or dup it. + * @return The SDVariable parameters of the vertex. + */ + public static Map defineParams(SameDiff sameDiff, GraphVertex vertex, boolean useView){ + Map params = new HashMap<>(vertex.paramTable(false)); + vertex.transformParamsForSameDiff(params); + return defineTransformedParams(sameDiff, params, (int) vertex.numParams(), useView); + } + + /** + * A helper for parameter definition. + */ + private static Map defineTransformedParams(SameDiff sameDiff, Map params, int numParams, boolean useView){ + Map newParams = new HashMap<>(numParams); + for (Map.Entry entry : params.entrySet()) { + INDArray value = entry.getValue(); + if (!useView) { + value = value.dup(); + } + newParams.put(entry.getKey(), sameDiff.var(entry.getKey(), value)); + } + return newParams; + } + + public static List getScopeNames(Layer[] layers){ + List names = new ArrayList<>(); + Map numLayers = new HashMap<>(); + + for (Layer layer : layers) { + org.deeplearning4j.nn.conf.layers.Layer config = layer.conf().getLayer(); + String baseName = config.getLayerName() == null ? config.getClass().getSimpleName() : config.getLayerName(); + + int layerNum = 0; + + if (numLayers.containsKey(baseName)) { + layerNum = numLayers.get(baseName); + numLayers.put(baseName, ++layerNum); + } else { + numLayers.put(baseName, 0); + } + names.add(baseName + (layerNum == 0 ? "" : "_" + layerNum)); + } + + return names; + } + + /** + * Copy the state from a MultiLayerNetwork or ComputationGraph updater to a SameDiff instance. + * @param sameDiff The SameDiff to copy to. + * @param updater The updater to copy from. + * @param layers The layers of the network or graph. + */ + public static void copyUpdaterState(@NonNull SameDiff sameDiff, BaseMultiLayerUpdater updater, Layer[] layers){ + if(updater == null) + return; + + List layerList = null; + List layerNames = null; + if(layers != null) { + layerNames = getScopeNames(layers); + layerList = Arrays.asList(layers); + } + + // layer -> param -> updater param -> array + Map>> layerParamStates = new HashMap<>(); + for(UpdaterBlock ub : updater.getUpdaterBlocks()){ + List params = ub.getLayersAndVariablesInBlock(); + int blockPStart = ub.getParamOffsetStart(); + int blockPEnd = ub.getParamOffsetEnd(); + + int blockUStart = ub.getUpdaterViewOffsetStart(); + int blockUEnd = ub.getUpdaterViewOffsetEnd(); + + int paramsMultiplier = (blockUEnd-blockUStart)/(blockPEnd-blockPStart); //Updater state length should be exactly 0, 1, 2 or 3x number of params + + INDArray updaterView = ub.getUpdaterView(); + long nParamsInBlock = blockPEnd - blockPStart; + + long soFar = 0; + for( int sub=0; sub state = ub.getGradientUpdater().getState(); + + long offsetWithinSub = 0; + for (UpdaterBlock.ParamState ps : params) { + + String namespace; + if(ps.getLayer() instanceof GraphVertex){ + namespace = ((GraphVertex) ps.getLayer()).getVertexName(); + } else { + Layer layer = (Layer) ps.getLayer(); + namespace = layerNames.get(layerList.indexOf(layer)); + } + + String paramName = namespace + "/" + ps.getParamName(); + + INDArray pv = ps.getParamView(); + long nParamsThisParam = pv.length(); + + + Map paramState = new HashMap<>(); + for(String k : state.keySet()){ + paramState.put(k, state.get(k).get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(offsetWithinSub, offsetWithinSub + nParamsThisParam))); + } + + Map> layerState = layerParamStates.get(ps.getLayer()); + + if(layerState == null){ + layerState = new HashMap<>(); + layerParamStates.put(ps.getLayer(), layerState); + } + + offsetWithinSub += nParamsThisParam; + layerState.put(ps.getParamName(), paramState); + } + + soFar += nParamsInBlock; + } + } + + // transform gradient params like weights are transformed + + Map> paramUpdaterStates = new HashMap<>(); + for(Map.Entry>> entry : layerParamStates.entrySet()){ + Trainable trainable = entry.getKey(); + Map> byParam = entry.getValue(); + + String namespace; + if(trainable instanceof GraphVertex){ + namespace = ((GraphVertex) trainable).getVertexName(); + } else { + Layer layer = (Layer) trainable; + namespace = layerNames.get(layerList.indexOf(layer)); + } + + namespace += "/"; + + if(entry.getValue().isEmpty()) + continue; + + // updaterParam -> param -> arr + // need param -> arr to feed to transform methods + Map> byUpdaterParam = new HashMap<>(); + for(String param : byParam.keySet()){ + for(Map.Entry uEntry : byParam.get(param).entrySet()){ + String updaterParam = uEntry.getKey(); + + Map updaterParamMap = byUpdaterParam.get(updaterParam); + if(updaterParamMap == null){ + updaterParamMap = new HashMap<>(); + byUpdaterParam.put(updaterParam, updaterParamMap); + } + + String sdName = namespace + param; + SDVariable v = sameDiff.getVariable(sdName); + INDArray sdArr = v.getArr(); + + updaterParamMap.put(param, uEntry.getValue().dup().reshape(sdArr.ordering(), sdArr.shape())); + } + } + + for(String updaterParam : byUpdaterParam.keySet()){ + Map byParamMap = byUpdaterParam.get(updaterParam); + if(trainable instanceof GraphVertex){ + ((GraphVertex) trainable).transformParamsForSameDiff(byParamMap); + } else { + Layer layer = (Layer) trainable; + layer.conf().getLayer().transformParamsForSameDiff(byParamMap); + } + + for(String param : byParam.keySet()){ + String sdName = namespace + param; + Map finalByUpdaterParam = paramUpdaterStates.get(sdName); + if(finalByUpdaterParam == null){ + finalByUpdaterParam = new HashMap<>(); + paramUpdaterStates.put(sdName, finalByUpdaterParam); + } + + finalByUpdaterParam.put(updaterParam, byParamMap.get(param)); + } + } + + } + + if (sameDiff.getTrainingConfig() == null) { + throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig"); + } + + Map> updaterMap = new HashMap<>(); + for (Variable v : sameDiff.getVariables().values()) { + if (v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()) { + //Skip non-trainable parameters + continue; + } + + INDArray arr = v.getVariable().getArr(); + long stateSize = sameDiff.getTrainingConfig().getUpdater().stateSize(arr.length()); + + Map params; + if(stateSize > 0) { + if (paramUpdaterStates.containsKey(v.getVariable().name())) { + params = paramUpdaterStates.get(v.getVariable().name()); + } else { + throw new IllegalStateException("No updater state found for variable " + v.getVariable().name()); + } + } else { + params = new HashMap<>(); + } + + for(String k : params.keySet()){ + params.put(k, params.get(k)); + } + + GradientUpdater gu = sameDiff.getTrainingConfig().getUpdater().instantiate(params, false); +// gu.setState(params, false); + updaterMap.put(v.getName(), gu); + } + + sameDiff.initializeTraining(updaterMap); + + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java index 058789aa22fb..eb98c9e3b0fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java @@ -36,12 +36,20 @@ public enum IndexType{ public SDIndex(){} - + + /** + * Create an index that gets the entire dimension. + */ public static SDIndex all(){ return new SDIndex(); } - + + /** + * Create a point index. The dimension will not be kept. + * Negative values are supported, and interpreted as being from the end (like Python indexing). + * @param i The index to get. + */ public static SDIndex point(long i){ SDIndex sdIndex = new SDIndex(); sdIndex.indexType = IndexType.POINT; @@ -51,6 +59,12 @@ public static SDIndex point(long i){ } + /** + * Create a point index. + * Negative values are supported, and interpreted as being from the end (like Python indexing). + * @param i The index to get. + * @param keepDim Whether to keep the dimension. + */ public static SDIndex point(long i, boolean keepDim){ SDIndex sdIndex = new SDIndex(); sdIndex.indexType = IndexType.POINT; @@ -59,6 +73,12 @@ public static SDIndex point(long i, boolean keepDim){ return sdIndex; } + /** + * Create an interval index with a stride of 1. + * Negative values are supported, and interpreted as being from the end (like Python indexing). + * @param begin The beginning of the interval. + * @param end The end of the interval (exclusive). + */ public static SDIndex interval(Long begin, Long end){ SDIndex sdIndex = new SDIndex(); sdIndex.indexType = IndexType.INTERVAL; @@ -67,6 +87,13 @@ public static SDIndex interval(Long begin, Long end){ return sdIndex; } + + /** + * Create an interval index with a stride of 1. + * Negative values are supported, and interpreted as being from the end (like Python indexing). + * @param begin The beginning of the interval. + * @param end The end of the interval (exclusive). + */ public static SDIndex interval(Integer begin, Integer end){ SDIndex sdIndex = new SDIndex(); sdIndex.indexType = IndexType.INTERVAL; @@ -79,6 +106,14 @@ public static SDIndex interval(Integer begin, Integer end){ return sdIndex; } + + /** + * Create an interval index. + * Negative endpoints are supported, and interpreted as being from the end (like Python indexing). + * @param begin The beginning of the interval. + * @param strides The stride of the interval. + * @param end The end of the interval (exclusive). + */ public static SDIndex interval(Long begin, Long strides, Long end){ if(strides == 0){ throw new ND4JIllegalArgumentException("Invalid index : strides can not be 0."); @@ -91,6 +126,13 @@ public static SDIndex interval(Long begin, Long strides, Long end){ return sdIndex; } + /** + * Create an interval index. + * Negative endpoints are supported, and interpreted as being from the end (like Python indexing). + * @param begin The beginning of the interval. + * @param strides The stride of the interval. + * @param end The end of the interval (exclusive). + */ public static SDIndex interval(Integer begin, Integer strides, Integer end){ if(strides == 0){ throw new ND4JIllegalArgumentException("Invalid index : strides can not be 0."); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index d78c6b5b36d9..7a56b5ea2cc6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -1586,7 +1586,9 @@ public SDVariable convertToVariable(){ /** * Rename this variable to a new name. Equivalent to {@link SameDiff#renameVariable(String, String)} * - * @param newName The new name for the variable - no variable with this name must already exist + * See {@link #rename(String, boolean)}. + * + * @param newName The new name for the variable - no variable with this name must already exist in this scope * @return The current variable (same object) */ public SDVariable rename(String newName) { @@ -1594,6 +1596,17 @@ public SDVariable rename(String newName) { return this; } + /** + * Rename this variable to a new name. Equivalent to {@link SameDiff#renameVariable(String, String, boolean)} + * + * @param newName The new name for the variable - no variable with this name must already exist (in this scope if includeScope is true) + * @return The current variable (same object) + */ + public SDVariable rename(String newName, boolean includeScope) { + sameDiff.renameVariable(getVarName(), newName, includeScope); + return this; + } + /** * Mark this variable as a loss function variable. This means that this variable will be minimized via backprop during training.
* This will add the variable as a loss to any others - i.e., if multiple variables are marked as losses, their values will be summed diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 1535ed105e43..76b345e13957 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -362,7 +362,7 @@ public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull A } /** - * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. + * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. Does not include a '/' postfix. */ public String currentNameScope() { if (nameScopes.isEmpty()) @@ -476,8 +476,9 @@ public List getOpsInScope(String scope){ */ public List getVariablesInScope(NameScope scope) { ArrayList vars = new ArrayList<>(); + String scopeName = scope.getName() + "/"; for (SDVariable v : variables()) { - if (v.name().startsWith(scope.getName())) + if (v.name().startsWith(scopeName)) vars.add(v); } return vars; @@ -1312,6 +1313,34 @@ public List outputs() { return this.outputs; } + + + /** + * See {@link #setOutputs(List)} + */ + public void setOutputs(SDVariable... outputs){ + setVariableOutputs(outputs == null ? null : Arrays.asList(outputs)); + } + + + /** + * See {@link #setOutputs(List)} + */ + public void setVariableOutputs(List outputs){ + + if(outputs != null){ + List names = new ArrayList<>(outputs.size()); + for(SDVariable sdv : outputs){ + Preconditions.checkArgument(sdv.sameDiff == this, "Can't set output to SDVariable in different SameDiff instance."); + names.add(sdv.name()); + } + + setOutputs(names); + } else { + setOutputs((List) null); + } + } + /** * See {@link #setOutputs(List)} */ @@ -1330,7 +1359,7 @@ public void setOutputs(String... outputs){ public void setOutputs(List outputs){ if(outputs != null){ for(String s : outputs){ - Preconditions.checkArgument(variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name"); + Preconditions.checkArgument(variables.containsKey(s), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name", s); } } this.outputs = outputs; @@ -1349,7 +1378,7 @@ public List variables() { /** * Get the names of variables (if any) that have been marked as loss variables to be minimized.
* Variables can be marked as loss variables in a few different ways:
- * (a) Losses are automatically added when creating loss functions via {@link #sd()}
+ * (a) Losses are automatically added when creating loss functions via {@link #loss()}
* (b) Via {@link #setLossVariables(String...)}, @link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
* (c) Via {@link TrainingConfig#setLossVariables(List)}
*/ @@ -1889,7 +1918,8 @@ protected void initializeTraining() { if (trainingConfig == null) { throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig"); } - updaterMap = new HashMap<>(); + + Map> updaterMap = new HashMap<>(); for (Variable v : variables.values()) { if (v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()) { //Skip non-trainable parameters @@ -1899,15 +1929,41 @@ protected void initializeTraining() { INDArray arr = v.getVariable().getArr(); long stateSize = trainingConfig.getUpdater().stateSize(arr.length()); INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize); - GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false); + GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false); gu.setStateViewArray(view, arr.shape(), arr.ordering(), true); updaterMap.put(v.getName(), gu); } - initializedTraining = true; + initializeTraining(updaterMap); } } + /** + * Initalize training with the specified gradient updaters. Overwrites the existing gradient updaters if there are any. + * @param gradientUpdaters The gradient updaters to use. Must specify one for each trainable variable. + */ + public void initializeTraining(Map> gradientUpdaters) { + if (trainingConfig == null) { + throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig"); + } + updaterMap = new HashMap<>(); + for (Variable v : variables.values()) { + if (v.getVariable().getVariableType() != VariableType.VARIABLE || !v.getVariable().dataType().isFPType()) { + //Skip non-trainable parameters + continue; + } + + GradientUpdater gu = gradientUpdaters.get(v.getVariable().name()); + + if(gu == null) + throw new IllegalArgumentException("Must specify a gradient updater for each trainable parameter: missing " + v.getVariable().name()); + + updaterMap.put(v.getName(), gu); + } + + initializedTraining = true; + } + /** * Convert the MultiDataSet to a {@code Map} based on the TrainingConfig settings. * The key is the placeholder/variable that the value INDArray should be associated with. @@ -2713,9 +2769,8 @@ public SDVariable constant(String name, @NonNull INDArray constant) { * @return SDVariable placeholder */ public SDVariable placeHolder(@NonNull String name, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - Preconditions.checkState(!variables.containsKey(name), "Variable already exists with name %s", name); SDVariable ret = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType); - variables.put(name, Variable.builder().name(name).variable(ret).build()); + addVariable(ret); return ret; } @@ -3317,13 +3372,30 @@ public void convertDataTypes(@NonNull Map dataTypeMap) { } } + /** + * Rename the specified variable to the new name, adding the current scope to the new name. + * + * See {@link #renameVariable(String, String, boolean)}. + * + * @param from The variable to rename - this variable must exist + * @param to The new name for the variable - no variable with this name must already exist in this scope + */ + public void renameVariable(String from, String to){ + renameVariable(from, to, true); + } + /** * Rename the specified variable to the new name. * * @param from The variable to rename - this variable must exist - * @param to The new name for the variable - no variable with this name must already exist + * @param to The new name for the variable - no variable with this name must already exist (in this scope if includeScope is true) + * @param includeScope Whether to add the current NameScope to the new name. True by default. */ - public void renameVariable(String from, String to) { + public void renameVariable(String from, String to, boolean includeScope) { + + if(includeScope && currentNameScope() != null) + to = currentNameScope() + "/" + to; + Preconditions.checkState(variables.containsKey(from), "Cannot rename variable \"%s\": no variable with this name exists", from); Preconditions.checkState(!variables.containsKey(to), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", from, to, to); @@ -3772,7 +3844,6 @@ public SDVariable addVariable(SDVariable variable) { throw new IllegalArgumentException("Variable with name \"" + variable.name() + "\" already exists"); } - Preconditions.checkState(variable.getSameDiff() == this, "Same diff instance for variable must be the same!"); variables.put(variable.name(), Variable.builder().name(variable.name()).variable(variable).build()); return variable; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/BaseActivationFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/BaseActivationFunction.java index ccdb119e6b2d..bc0bf3655e12 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/BaseActivationFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/BaseActivationFunction.java @@ -16,6 +16,9 @@ package org.nd4j.linalg.activations; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -38,4 +41,8 @@ protected void assertShape(INDArray in, INDArray epsilon){ + ", epsilon.shape() = " + Arrays.toString(epsilon.shape())); } } + + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input){ + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName()); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/IActivation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/IActivation.java index e67a99c8a39b..144fbc295127 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/IActivation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/IActivation.java @@ -16,6 +16,9 @@ package org.nd4j.linalg.activations; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; import org.nd4j.serde.json.LegacyIActivationDeserializerHelper; @@ -25,7 +28,8 @@ import java.io.Serializable; /** - * Interface for implementing custom activation functions + * Interface for implementing custom activation functions. + * Custom activation functions should probably extend {@link BaseActivationFunction} instead of this interface. */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", defaultImpl = LegacyIActivationDeserializerHelper.class) @@ -59,4 +63,15 @@ public interface IActivation extends Serializable { int numParams(int inputSize); + /** + * Define the activation function for conversion to {@link SameDiff}.
+ * If this isn't supported, this method should throw a {@link UnsupportedOperationException} + * like the default implementation in {@link BaseActivationFunction}. + * + * @param sameDiff The SameDiff instance to define in. + * @param input The input to the activation function. + * @return The output of the activation function. + */ + SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java index c1c938a9895d..4d365106371d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java @@ -19,6 +19,8 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; @@ -47,6 +49,11 @@ public Pair backprop(@NonNull INDArray in, @NonNull INDArray return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.math.cube(input); + } + @Override public String toString() { return "cube"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index 71f31cba5e71..25816fc8bf80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; @@ -68,6 +71,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.elu(input); + } + @Override public String toString() { return "elu(alpha=" + alpha + ")"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java index 953cb258763d..47130bae6ebc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU; @@ -68,6 +71,14 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(dLdz, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + if(precise) + return sameDiff.nn.preciseGelu(input); + else + return sameDiff.nn.gelu(input); + } + @Override public String toString() { return "gelu(precise="+precise+")"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java index 623757bb0617..61c40a3f79d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; @@ -46,6 +49,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.hardSigmoid(input); + } + @Override public String toString() { return "hardsigmoid"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java index cc11a38106cc..ab4a6a2d2ee2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; @@ -49,6 +52,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.hardTanh(input); + } + @Override public String toString() { return "hardtanh"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java index f86f1c21ef80..4427177e0c15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,6 +44,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(epsilon, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return input; + } + @Override public String toString() { return "identity"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java index 8e0faf160910..9a9e68c1010f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; @@ -60,6 +63,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.leakyRelu(input, alpha); + } + @Override public String toString() { return "leakyrelu(a=" + alpha + ")"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java index 7ab79d81fb54..f7b25a65e9bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java @@ -19,6 +19,10 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.util.ArrayUtil; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -78,6 +82,12 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, dLdalpha); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + SDVariable alpha = sameDiff.var("alpha", getAlpha().dup()); + return sameDiff.nn.prelu(input, alpha, sharedAxes == null ? new int[]{} : ArrayUtil.toInts(sharedAxes)); + } + @Override public String toString() { return "prelu"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java index dd0bf65662bc..f76c6578b62d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; @@ -54,6 +57,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.math.rationalTanh(input); + } + @Override public String toString() { return "rationaltanh"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java index 52cc5015ef35..13b6df44eb83 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.scalar.*; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; import org.nd4j.common.primitives.Pair; @@ -101,6 +104,33 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(dLdz, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + SDVariable temp; + double thresh = threshold == null ? 0.0 : threshold; + double ns = negativeSlope == null ? 0.0 : negativeSlope; + if(ns == 0){ + temp = sameDiff.nn.relu(input, thresh); + } else { + if(thresh == 0) + temp = sameDiff.nn.leakyRelu(input, negativeSlope); + else { + //TODO optimize this + SDVariable t = sameDiff.constant(Nd4j.scalar(input.dataType(), thresh)); + SDVariable oneGte = input.gte(t).castTo(input.dataType()); + SDVariable oneLt = input.lt(t).castTo(input.dataType()); + SDVariable lower = oneLt.mul(ns).mul(input.sub(threshold)); + SDVariable upper = oneGte.mul(input); + temp = lower.add(upper); + } + } + + if(max != null) + temp = sameDiff.math.max(sameDiff.constant(Nd4j.scalar(temp.dataType(), max)), temp); + + return temp; + } + @Override public String toString() { return "relu"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java index d5e0cb76cf37..41b727a7639c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.Relu6; @@ -47,6 +50,12 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + // 0 is the default cutoff in the op + return sameDiff.nn.relu6(input, 0); + } + @Override public String toString() { return "relu6"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java index 9764e55d2e0e..c6562921e0c7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; @@ -51,6 +54,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.math.rectifiedTanh(input); + } + @Override public String toString() { return "rectifiedtanh"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java index d7f71e38285f..53c3fb9470a1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; @@ -47,6 +50,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.selu(input); + } + @Override public String toString() { return "selu"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java index de2db231a3d7..df5b73cea198 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; @@ -47,6 +50,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.sigmoid(input); + } + @Override public String toString() { return "sigmoid"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java index 6f2c82ba7850..abf873c26b03 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp; @@ -47,6 +50,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.softplus(input); + } + @Override public String toString() { return "softplus"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java index aefe0297ee22..12be09104d60 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; @@ -47,6 +50,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.softsign(input); + } + @Override public String toString() { return "softsign"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java index dfcabe2699bb..180dc0a2d573 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; @@ -53,4 +56,8 @@ public String toString() { return "softmax"; } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.softmax(input); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSwish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSwish.java index 793181112f4a..dd48ed7ffb04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSwish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSwish.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish; @@ -46,6 +49,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(dLdz, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.swish(input); + } + @Override public String toString() { return "swish"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java index 2b29bada4352..dc3fe1721aaa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java @@ -18,6 +18,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; @@ -47,6 +50,11 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + return sameDiff.nn.tanh(input); + } + @Override public String toString() { return "tanh"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java index 5eb1c3e9c888..e7d03effde9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java @@ -18,7 +18,11 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.BaseActivationFunction; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -64,6 +68,12 @@ public Pair backprop(INDArray in, INDArray epsilon) { return new Pair<>(in, null); } + @Override + public SDVariable defineActivation(@NonNull SameDiff sameDiff, @NonNull SDVariable input) { + //TODO the mul works around a bug in relu, should only need the relu call https://github.com/eclipse/deeplearning4j/issues/9018 + return sameDiff.nn.relu(input, theta).mul(input.gt(theta).castTo(input.dataType())); + } + @Override public String toString() { return "thresholdedrelu(theta=" + theta + ")"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index f1bafe0de7ec..cf1932d7a4d2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -152,18 +152,23 @@ public BaseReduceOp(SameDiff sameDiff) { @Override public INDArray noOp() { + return noOp(x, z); + } + + @Override + public INDArray noOp(INDArray x, INDArray z) { if (z != null && x != z) - return z().assign(x); + return z.assign(x); else { //Need to take into account shapes: for example, [1,3].sum(0) -> [3] //Or [1,1,1,1].sum(0,2,3) -> [1] if(keepDims){ - return x().dup(x().ordering()); + return x.dup(x.ordering()); } else { long[] shape = x.shape(); if(dimensions == null || Shape.isWholeArray(shape, dimensions)){ //Return scalar - return x.reshape().dup(); + return x.reshape(-1).dup(); } else { //Strip out size 1 dimensions long[] outShape = ArrayUtil.removeIndex(shape, dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java index 23d81c5b4895..3e3942d04930 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/ReduceOp.java @@ -62,6 +62,8 @@ public interface ReduceOp extends Op { */ INDArray noOp(); + INDArray noOp(INDArray x, INDArray z); + /** * This method returns dimensions for this op * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/BaseLossFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/BaseLossFunction.java new file mode 100644 index 000000000000..cb5427b02996 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/BaseLossFunction.java @@ -0,0 +1,117 @@ +/* + * ****************************************************************************** + * * Copyright (c) 2020 Konduit K.K. + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.nd4j.linalg.lossfunctions; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.ops.SDLoss; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * This class can be extended in two ways:
+ *
    + *
  • Implement {@code defineLoss}. It will be called to define the loss function. You must handle averaging yourself.
  • + *
  • Implement {@code defineLossArray}. It will be called from the default implementation of {@code defineLoss}, which automatically handles averaging. + * The use of {@link SDLoss} ops or any ops that don't accept output gradients is FORBIDDEN. + * If you want to use those ops you must use the other extension method.
  • + *
+ */ +public abstract class BaseLossFunction implements ILossFunction { + + /** + * Helper function to sum or average a loss array depending on the parameter.
+ * + * Should not be used from {@link #defineLossArray(SameDiff, SDVariable, SDVariable)} baring special circumstances, + * as the averaging is handled automatically (using this function). + * + * @param output The loss array. + * @param labels The labels, used to find the batch size. + * @param average Whether to average the array (sums if false). + * @return The scalar average or sum, depending on the parameter. + */ + protected static SDVariable reduceLossArray(SDVariable output, SDVariable labels, boolean average){ + output = output.sum(); + if(average) //TODO without cast, only fails on backprop + return output.div(output.getSameDiff().sizeAt(labels, 0).castTo(output.dataType())); + else + return output; + } + + /** + * Helper function to apply a weight to the loss.
+ * Should only be used from {@link #defineLossArray(SameDiff, SDVariable, SDVariable)} baring special circumstances. + * + * @param loss The loss array. + * @param weight The weight. + */ + protected static SDVariable multiplyWeight(@NonNull SDVariable loss, INDArray weight){ + return LossUtil.multiplyWeight(loss, weight); + } + + /** + * Define the loss function for a {@link SameDiff} instance by defining a per-example score array, which is averaged automatically if necessary.
+ * + * The default implementation of {@link #defineLoss(SameDiff, SDVariable, SDVariable, boolean)} will call this method, + * so a subclass of {@link BaseLossFunction} can define only this method.
+ * + * However, when using this method, the use of {@link SDLoss} ops and any other ops that don't accept an output gradient is FORBIDDEN. + * This is due to the fact that we have to do sum/average ops to the output of this function. + * If those ops are nessecary, implement {@link #defineLoss(SameDiff, SDVariable, SDVariable, boolean)} instead.
+ * + * @param sameDiff The {@link SameDiff} instance + * @param input The input to the loss function, typically the output of the previous layer. + * @param labels The labels to compare the output to. Should be the same shape as input. + * @return The score array. The first dimension should be the batch, so it has shape {@code [batchSize, ...]}. + */ + protected SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable labels){ + throw new UnsupportedOperationException("defineLossArray not implemented for " + this.getClass().getSimpleName()); + } + + /** + * + * Define the loss function for a {@link SameDiff} instance. Should return a scalar.
+ * + * If average is true, should be the batchwise average, otherwise the sum.
+ * + * The default implementation of this method calls {@link #defineLossArray(SameDiff, SDVariable, SDVariable)} and then averages it if necessary. + * If you are not using {@link SDLoss} ops often it is easier to implement {@link #defineLossArray(SameDiff, SDVariable, SDVariable)} when extending this class. + * However, if you are, you must implement this method instead. See {@link BaseLossFunction}.
+ * + * Note that using a {@link org.nd4j.autodiff.samediff.ops.SDLoss} function with {@link org.nd4j.autodiff.loss.LossReduce} MEAN_BY_NONZERO_WEIGHT_COUNT + * will result in the loss values for DL4J and SameDiff being slightly different, but is as close as you can get. + * DL4J gets the loss with average=false and then averages it itself, while SameDiff will work with what you pass it. + * + * @see BaseLossFunction + * @param sameDiff The {@link SameDiff} instance + * @param input The input to the loss function, typically the output of the previous layer. + * @param labels The labels to compare the output to. Should be the same shape as input. + * @param average Whether to average the loss per example. + * @return The scalar score (loss function value). + */ + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable labels, boolean average) { + try{ + return reduceLossArray(defineLossArray(sameDiff, input, labels), labels, average); + } catch (UnsupportedOperationException e){ + throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName(), e); + } + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/ILossFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/ILossFunction.java index 7af044937c55..b27a3ac06658 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/ILossFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/ILossFunction.java @@ -17,6 +17,9 @@ package org.nd4j.linalg.lossfunctions; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -26,7 +29,7 @@ import java.io.Serializable; /** - * Interface for loss functions + * Interface for loss functions. Custom loss functions should probably extend {@link BaseLossFunction} instead of this interface. */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", defaultImpl = LegacyILossFunctionDeserializerHelper.class) @@ -78,6 +81,27 @@ public interface ILossFunction extends Serializable { Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average); + /** + * Define the loss function for a {@link SameDiff} instance. Should return a scalar.

+ * If this isn't supported, this method should throw a {@link UnsupportedOperationException} + * like the default implementation in {@link BaseLossFunction}.
+ * + * If average is true, should be the batchwise average, otherwise the sum.
+ * + * Note that using a {@link org.nd4j.autodiff.samediff.ops.SDLoss} function with {@link org.nd4j.autodiff.loss.LossReduce} MEAN_BY_NONZERO_WEIGHT_COUNT + * will result in the loss values for DL4J and SameDiff being slightly different. + * DL4J gets the loss with average=false and then averages it itself, while SameDiff will work with what you pass it. + * + * @param sameDiff The {@link SameDiff} instance + * @param input The input to the loss function, typically the output of the previous layer. + * @param labels The labels to compare the output to. Should be the same shape as input. + * @param average Whether to average the loss per example. + * @return The scalar score (loss function value). + */ + SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable labels, boolean average); + + //TODO defineLossArray method. Or make defineLoss take LossReduce + /** * The opName of this function * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java index 9ac7e86cde38..85bfc8e43a31 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossUtil.java @@ -16,6 +16,9 @@ package org.nd4j.linalg.lossfunctions; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; @@ -53,4 +56,16 @@ public static void applyMask(INDArray to, INDArray mask) { + Arrays.toString(mask.shape()) + ", output shape: " + Arrays.toString(to.shape())); } } + + public static SDVariable multiplyWeight(@NonNull SDVariable loss, INDArray weight){ + if(weight == null){ + return loss; + } else { + return loss.mul(loss.getSameDiff().constant(weight.castTo(loss.dataType()))); + } + } + + public static SDVariable batchAverage(@NonNull SDVariable loss){ + return loss.sum().div(loss.getSameDiff().sizeAt(loss, 0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java index e78376b5ffd0..d4eb4002d760 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java @@ -29,42 +29,82 @@ * SameDiff loss function. * * This class can be extended to create Deeplearning4j loss functions by defining one single method only: - * {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. This method is used to define the loss function on a + * {@link #defineLoss(SameDiff, SDVariable, SDVariable, boolean)}. This method is used to define the loss function on a * per example basis - i.e., the output should be an array with shape [minibatch].
*
* For example, the mean squared error (MSE) loss function can be defined using:
* {@code return labels.squaredDifference(layerInput).mean(1);} * */ -public abstract class SameDiffLoss implements ILossFunction { - protected transient SameDiff sd; - protected transient SDVariable scorePerExampleVariable; +public abstract class SameDiffLoss extends BaseLossFunction { + protected transient SameDiff sumSD; + protected transient SameDiff averageSD; + protected transient SameDiff arraySD; + protected static final String LOSS_VAR_NAME = "loss"; protected SameDiffLoss() { } - /** - * Define the loss function.
- * NOTE: The score on a *per example* basis - should return a SDVariable with shape [minibatch], where out[i] - * is the score for the ith minibatch - * - * @param sd SameDiff instance to define the loss on - * @param layerInput Input to the SameDiff loss function - * @param labels Labels placeholder - * @return The score on a per example basis (SDVariable with shape [minibatch]) - */ - public abstract SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels); +// /** +// * Define the loss function.
+// * NOTE: The score on a *per example* basis - should return a SDVariable with shape [minibatch], where out[i] +// * is the score for the ith minibatch +// * +// * @param sameDiff SameDiff instance to define the loss on +// * @param layerInput Input to the SameDiff loss function +// * @param labels Labels placeholder +// * @return The score on a per example basis (SDVariable with shape [minibatch]) +// */ +// public abstract SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable labels); + + protected void makeArraySDIfNeeded(DataType dataType){ + if(arraySD != null) + return; + + SameDiff sd = SameDiff.create(); + SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1); + SDVariable labels = sd.placeHolder("labels", dataType, -1); + SDVariable loss = this.defineLossArray(sd, layerInput, labels); + loss.rename(LOSS_VAR_NAME); + loss.markAsLoss(); + sd.createGradFunction("layerInput"); + + arraySD = sd; + } + + protected void createSameDiffInstance(DataType dataType, boolean average){ + SameDiff sd; + if(average) { + averageSD = SameDiff.create(); + sd = averageSD; + } else { + sumSD = SameDiff.create(); + sd = sumSD; + } - protected void createSameDiffInstance(DataType dataType){ - sd = SameDiff.create(); SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1); SDVariable labels = sd.placeHolder("labels", dataType, -1); - scorePerExampleVariable = this.defineLoss(sd, layerInput, labels); - scorePerExampleVariable.markAsLoss(); + SDVariable loss = this.defineLoss(sd, layerInput, labels, average); + loss.rename(LOSS_VAR_NAME); + loss.markAsLoss(); sd.createGradFunction("layerInput"); } + protected void createSameDiffInstanceIfRequired(DataType dataType, boolean average){ + if(average){ + if(averageSD == null) + createSameDiffInstance(dataType, average); + } else { + if(sumSD == null) + createSameDiffInstance(dataType, average); + } + } + + protected SameDiff getSameDiffInstance(boolean average){ + return average ? averageSD : sumSD; + } + /** * Compute the score (loss function value) for the given inputs. * @@ -76,17 +116,19 @@ protected void createSameDiffInstance(DataType dataType){ */ @Override public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { - if(sd == null){ - createSameDiffInstance(preOutput.dataType()); - } + createSameDiffInstanceIfRequired(preOutput.dataType(), average); - INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask); + Preconditions.checkArgument((labels.size(1) == preOutput.size(1)), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1)); - double score = scoreArr.sumNumber().doubleValue(); - if (average) { - score /= scoreArr.size(0); - } - return score; + INDArray output = activationFn.getActivation(preOutput.dup(), true); + + Map inputs = new HashMap<>(); + inputs.put("labels", labels); + inputs.put("layerInput", output); + + INDArray score = getSameDiffInstance(average).outputSingle(inputs, LOSS_VAR_NAME); + + return score.sumNumber().doubleValue(); } @@ -101,24 +143,16 @@ public double computeScore(INDArray labels, INDArray preOutput, IActivation acti */ @Override public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { - if(sd == null){ - createSameDiffInstance(preOutput.dataType()); - } - + makeArraySDIfNeeded(preOutput.dataType()); Preconditions.checkArgument((labels.size(1) == preOutput.size(1)), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1)); INDArray output = activationFn.getActivation(preOutput.dup(), true); - Map m = new HashMap<>(); - m.put("labels", labels); - m.put("layerInput", output); - - INDArray scoreArr = sd.outputSingle(m, scorePerExampleVariable.name()); + Map inputs = new HashMap<>(); + inputs.put("labels", labels); + inputs.put("layerInput", output); - if (mask != null) { - LossUtil.applyMask(scoreArr, mask); - } - return scoreArr; + return arraySD.outputSingle(inputs, LOSS_VAR_NAME); } @@ -133,9 +167,7 @@ public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivati */ @Override public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { - if(sd == null){ - createSameDiffInstance(preOutput.dataType()); - } + createSameDiffInstanceIfRequired(preOutput.dataType(), false); Map m = new HashMap<>(); @@ -143,14 +175,15 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation m.put("labels", labels); m.put("layerInput", output); - Map grads = sd.calculateGradients(m, "layerInput"); + Map grads = getSameDiffInstance(false).calculateGradients(m, "layerInput"); INDArray gradAtActivationOutput = grads.get("layerInput"); INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst(); - if (mask != null) { - LossUtil.applyMask(gradAtInput, mask); - } + //TODO no mask application in forward pass yet +// if (mask != null) { +// LossUtil.applyMask(gradAtInput, mask); +// } return gradAtInput; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java index f32cb7a2dcf7..d5b1fba732a6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java @@ -18,7 +18,10 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; @@ -28,7 +31,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -50,7 +53,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter @Setter -public class LossBinaryXENT implements ILossFunction { +public class LossBinaryXENT extends BaseLossFunction { public static final double DEFAULT_CLIPPING_EPSILON = 1e-5; @JsonSerialize(using = NDArrayTextSerializer.class) @@ -237,6 +240,23 @@ public Pair computeGradientAndScore(INDArray labels, INDArray computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + SDVariable scoreArr; + if(input.getCreator().opName().equals("softmax")){ + scoreArr = sameDiff.math.log(input).mul(labels); + } else { + input = sameDiff.math.clipByValue(input, clipEps, 1-clipEps); + scoreArr = sameDiff.math.log(input).mul(labels); + + SDVariable secondTerm = sameDiff.math.log(input.rsub(1)).mul(labels.rsub(1)); + + scoreArr = scoreArr.add(secondTerm); + } + return LossUtil.multiplyWeight(scoreArr.mul(-1), weights); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java index 1f8ac194237e..24cb4df87114 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossCosineProximity.java @@ -17,11 +17,14 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -31,7 +34,7 @@ * Created by susaneraly on 9/9/16. */ @EqualsAndHashCode -public class LossCosineProximity implements ILossFunction { +public class LossCosineProximity extends BaseLossFunction { /** * @@ -138,6 +141,12 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return sameDiff.math.cosineSimilarity(labels, input, 1).neg().reshape(-1, 1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java index dad0b9a7a854..40df39271952 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFMeasure.java @@ -18,10 +18,14 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.common.primitives.Pair; import org.nd4j.shade.jackson.annotation.JsonProperty; @@ -53,7 +57,7 @@ */ @Getter @EqualsAndHashCode -public class LossFMeasure implements ILossFunction { +public class LossFMeasure extends BaseLossFunction { public static final double DEFAULT_BETA = 1.0; @@ -180,6 +184,52 @@ public Pair computeGradientAndScore(INDArray labels, INDArray computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels, boolean average) { + long n = labels.placeholderShape()[1]; + if (n != 1 && n != 2) { + throw new UnsupportedOperationException( + "For binary classification: expect output size of 1 or 2. Got: " + n); + } + + //First: determine positives and negatives + SDVariable isPositiveLabel; + SDVariable isNegativeLabel; + SDVariable pClass0; + SDVariable pClass1; + if (n == 1) { + isPositiveLabel = labels; + isNegativeLabel = isPositiveLabel.rsub(1.0); + pClass0 = input.rsub(1.0); + pClass1 = input; + } else { + isPositiveLabel = labels.get(SDIndex.all(), SDIndex.point(1)); + isNegativeLabel = labels.get(SDIndex.all(), SDIndex.point(0)); + pClass0 = input.get(SDIndex.all(), SDIndex.point(0)); + pClass1 = input.get(SDIndex.all(), SDIndex.point(1)); + } + + SDVariable tp = isPositiveLabel.mul(pClass1).sum(); + SDVariable fp = isNegativeLabel.mul(pClass1).sum(); + SDVariable fn = isPositiveLabel.mul(pClass0).sum(); + + SDVariable numerator = tp.mul(1.0 + beta * beta); + SDVariable denominator = tp.mul(1.0 + beta * beta).add(fn.mul(beta * beta)).add(fp); + + SDVariable eps = sameDiff.constant(Nd4j.EPS_THRESHOLD); + numerator = sameDiff.math.max(sameDiff.math.abs(numerator), eps).mul(sameDiff.math.sign(numerator)); + denominator = sameDiff.math.max(sameDiff.math.abs(denominator), eps).mul(sameDiff.math.sign(denominator)); + + // have to use labels to get batch size + SDVariable out = numerator.div(denominator).rsub(1).sum(); + + if(average) + return out.div(sameDiff.sizeAt(labels, 0)); + else + return out; + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java index 499a42133b0b..0e63d5bc2a86 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossHinge.java @@ -17,12 +17,16 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.common.primitives.Pair; @@ -30,7 +34,7 @@ * Created by susaneraly on 8/15/16. */ @EqualsAndHashCode -public class LossHinge implements ILossFunction { +public class LossHinge extends BaseLossFunction { public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { if(!labels.equalShapes(preOutput)){ @@ -115,6 +119,12 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return sameDiff.math.max(input.mul(labels).rsub(1), sameDiff.constant(Nd4j.scalar(input.dataType(), 0.0))).sum(true, 1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java index ffcde7302ab4..383a49b9ff9b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossKLD.java @@ -18,11 +18,14 @@ import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -33,7 +36,7 @@ * @author Susan Eraly */ @EqualsAndHashCode -public class LossKLD implements ILossFunction { +public class LossKLD extends BaseLossFunction { private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { if(!labels.equalShapes(preOutput)){ @@ -111,6 +114,14 @@ public Pair computeGradientAndScore(INDArray labels, INDArray computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + input = sameDiff.math.clipByValue(input, Nd4j.EPS_THRESHOLD, 1); + labels = sameDiff.math.clipByValue(labels, Nd4j.EPS_THRESHOLD, 1); + + return sameDiff.math.log(input.rdiv(labels)).mul(labels); + } /** * The opName of this function diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java index 56ffaee086f9..d4fcfc94d754 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL1.java @@ -18,12 +18,15 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -42,7 +45,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter -public class LossL1 implements ILossFunction { +public class LossL1 extends BaseLossFunction { @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) @@ -152,6 +155,16 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + protected SDVariable defineFullLossArray(SameDiff sameDiff, SDVariable input, SDVariable labels){ + return LossUtil.multiplyWeight(sameDiff.math.abs(input.sub(labels)), weights); + } + + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return defineFullLossArray(sameDiff, input, labels).sum(true, 1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java index b6d6bb761b31..f6419fde58cb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossL2.java @@ -18,10 +18,13 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.common.primitives.Pair; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; @@ -41,7 +44,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter -public class LossL2 implements ILossFunction { +public class LossL2 extends BaseLossFunction { @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) @@ -151,6 +154,15 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + protected SDVariable defineFullLossArray(SameDiff sameDiff, SDVariable input, SDVariable labels){ + return LossUtil.multiplyWeight(labels.squaredDifference(input), weights); + } + + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return defineFullLossArray(sameDiff, input, labels).sum(true, 1); + } /** * The opName of this function diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAE.java index a232e0311bb3..869ab5363b52 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAE.java @@ -17,8 +17,13 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.lossfunctions.LossUtil; /** * Mean absolute error loss function: L = 1/N sum_i abs(predicted_i - actual_i) @@ -67,6 +72,12 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation return gradients; } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return defineFullLossArray(sameDiff, input, labels).mean(true, 1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java index 03c9a461fb19..5259b817ad78 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMAPE.java @@ -18,13 +18,17 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.same.Abs; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -40,7 +44,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter -public class LossMAPE implements ILossFunction { +public class LossMAPE extends BaseLossFunction { @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) @@ -154,6 +158,12 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return LossUtil.multiplyWeight(sameDiff.math.abs(input.rsub(labels).div(labels)).mul(100).div(sameDiff.sizeAt(labels, 1).castTo(input.dataType())), weights); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java index 2863f282ea12..a4889d86c069 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java @@ -20,14 +20,17 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -52,7 +55,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter @Setter -public class LossMCXENT implements ILossFunction { +public class LossMCXENT extends BaseLossFunction { private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10; @JsonSerialize(using = NDArrayTextSerializer.class) @@ -202,6 +205,14 @@ public Pair computeGradientAndScore(INDArray labels, INDArray computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable labels) { + if(input.getCreator().opName().equals("softmax") && softmaxClipEps > 0.0){ + input = sameDiff.math.clipByValue(input, softmaxClipEps, 1.0-softmaxClipEps); + } + return LossUtil.multiplyWeight(sameDiff.math.log(input).mul(labels).neg(), weights); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java index bb64bb777fae..be48c255b1d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSE.java @@ -17,8 +17,13 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.shade.jackson.annotation.JsonProperty; /** @@ -64,6 +69,12 @@ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation return gradients.divi(labels.size(1)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return defineFullLossArray(sameDiff, input, labels).mean(true, 1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java index cf459e2aa69b..7c5847cb0b04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMSLE.java @@ -18,10 +18,14 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDIndex; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -39,7 +43,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter -public class LossMSLE implements ILossFunction { +public class LossMSLE extends BaseLossFunction { @JsonSerialize(using = NDArrayTextSerializer.class) @JsonDeserialize(using = NDArrayTextDeSerializer.class) @@ -152,6 +156,13 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + SDVariable score = sameDiff.math.log(input.add(1.0).div(labels.add(1.0))); + return LossUtil.multiplyWeight(score.mul(score).mean(true, 1), weights); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java index b58be0e8be5c..a143cea72f12 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; @@ -64,7 +65,7 @@ */ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) -public class LossMixtureDensity implements ILossFunction { +public class LossMixtureDensity extends BaseLossFunction { private int mMixtures; private int mLabelWidth; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java index 199ff4b4bdcd..ca0897491296 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMultiLabel.java @@ -21,6 +21,7 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; @@ -56,7 +57,7 @@ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) @Getter -public class LossMultiLabel implements ILossFunction { +public class LossMultiLabel extends BaseLossFunction { public LossMultiLabel() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java index 54ec246afd0c..28b511b4c3be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossPoisson.java @@ -17,10 +17,13 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; @@ -29,7 +32,7 @@ * Created by susaneraly on 9/9/16. */ @EqualsAndHashCode -public class LossPoisson implements ILossFunction { +public class LossPoisson extends BaseLossFunction { public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { if(!labels.equalShapes(preOutput)){ @@ -107,6 +110,12 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return sameDiff.math.log(input).mul(labels).rsub(input).sum(true,1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java index 43e76e511d18..aedfe90591e5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java @@ -19,7 +19,10 @@ import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -115,6 +118,11 @@ private INDArray toOneHot(INDArray labels, INDArray preOutput){ return oneHotLabels; } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + throw new UnsupportedOperationException("toSameDiff for LossSparseMCXENT is not yet supported."); + } @Override public String toString() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java index 8f4874d9a8b2..9c2d94f3dfbf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSquaredHinge.java @@ -17,12 +17,16 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.common.primitives.Pair; @@ -30,7 +34,7 @@ * Created by susaneraly on 9/9/16. */ @EqualsAndHashCode -public class LossSquaredHinge implements ILossFunction { +public class LossSquaredHinge extends BaseLossFunction { public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { if(!labels.equalShapes(preOutput)){ @@ -113,6 +117,13 @@ public Pair computeGradientAndScore(INDArray labels, computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + SDVariable hinge = sameDiff.math.max(input.mul(labels).rsub(1), sameDiff.constant(Nd4j.scalar(input.dataType(), 0.0))); + return hinge.mul(hinge).sum(true, 1); + } + /** * The opName of this function * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java index 5a837590ece4..357db59bb0eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossWasserstein.java @@ -17,11 +17,14 @@ package org.nd4j.linalg.lossfunctions.impl; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.common.primitives.Pair; @@ -38,7 +41,7 @@ * @author Ryan Nett */ @EqualsAndHashCode(callSuper = false) -public class LossWasserstein implements ILossFunction { +public class LossWasserstein extends BaseLossFunction { private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask){ if(!labels.equalShapes(preOutput)){ @@ -102,6 +105,13 @@ public Pair computeGradientAndScore(INDArray labels, INDArray return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average), computeGradient(labels, preOutput, activationFn, mask)); } + + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + return labels.mul(input).mean(true, 1); + } + @Override public String name() { return toString(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/schedule/RampSchedule.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/schedule/RampSchedule.java index 1a8c084e63c6..accd999138cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/schedule/RampSchedule.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/schedule/RampSchedule.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.schedule; +import lombok.Data; + /** * A "Wrapper" schedule that ramps up from {@code 1/numIter * baseLR} to {@code baseLR} over numIter iterations. * The base learning rate is determined by the underlying ISchedule, as a function of time. @@ -7,6 +9,7 @@ * * @author Alex Black */ +@Data public class RampSchedule implements ISchedule { protected final ISchedule baseSchedule; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 5e12f1dfdc92..0db3b66124f4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -294,7 +294,7 @@ public INDArray exec(ReduceOp op, OpContext oc) { if (x.isVector() && x.length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && y == null) - return op.noOp(); + return op.noOp(x, z); /** * This is the result array. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java index 9822962c4655..26bbf3b69b9e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.lossfunctions; import org.junit.Test; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; @@ -131,6 +133,21 @@ public void testWeightedLossFunctionDTypes(){ //Check backward lf.computeGradient(l, preOut, new ActivationSoftmax(), null); + + INDArray scoreArray = lf.computeScoreArray(l, preOut, new ActivationSoftmax(), null); + + // check SameDiff conversion + try { + SameDiff sameDiff = SameDiff.create(); + SDVariable input = sameDiff.nn.softmax(sameDiff.constant(preOut)); + SDVariable labels = sameDiff.constant(l); + SDVariable loss = lf.defineLoss(sameDiff, input, labels, false).sum(); + + assertTrue("SameDiff loss doesn't match INDArray loss", scoreArray.sum().equalsWithEps(loss.eval(), 1e-5)); + + } catch (UnsupportedOperationException e){ + + } } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java index ae4d45fb2ed5..1c666fb7890c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java @@ -17,9 +17,12 @@ package org.deeplearning4j.rl4j.network.ac; import lombok.EqualsAndHashCode; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.BaseLossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; import org.nd4j.linalg.ops.transforms.Transforms; @@ -39,7 +42,7 @@ */ @EqualsAndHashCode @JsonInclude(JsonInclude.Include.NON_NULL) -public class ActorCriticLoss implements ILossFunction { +public class ActorCriticLoss extends BaseLossFunction { public static final double BETA = 0.01; @@ -90,6 +93,14 @@ public Pair computeGradientAndScore(INDArray labels, INDArray computeGradient(labels, preOutput, activationFn, mask)); } + @Override + public SDVariable defineLossArray(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable labels) { + SDVariable log = sameDiff.math.log(input); + return log.mul(labels) + .sub(input.mul(log).mul(sameDiff.constant(BETA))); + } + @Override public String toString() { return "ActorCriticLoss()";