diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
index b74df2d2c8a3..1f7ce29c038a 100644
--- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
+++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java
@@ -23,7 +23,9 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
+import org.junit.rules.DisableOnDebug;
import org.junit.rules.TestName;
+import org.junit.rules.TestRule;
import org.junit.rules.Timeout;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JSystemProperties;
@@ -48,7 +50,7 @@ public abstract class BaseDL4JTest {
@Rule
public TestName name = new TestName();
@Rule
- public Timeout timeout = Timeout.millis(getTimeoutMilliseconds());
+ public TestRule timeout = new DisableOnDebug(Timeout.millis(getTimeoutMilliseconds()));
protected long startTime;
protected int threadCountBefore;
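
Wrapping the Timeout rule in DisableOnDebug means the per-test timeout is enforced in normal runs but suspended when the JVM is started under a debugger (DisableOnDebug inspects the runtime's input arguments for the jdwp agent). A minimal self-contained sketch of the same pattern, with an assumed 90-second budget:

    import org.junit.Rule;
    import org.junit.Test;
    import org.junit.rules.DisableOnDebug;
    import org.junit.rules.TestRule;
    import org.junit.rules.Timeout;

    public class TimeoutRuleExample {
        // Enforced in normal runs; a no-op while a debugger (jdwp agent) is attached.
        @Rule
        public TestRule timeout = new DisableOnDebug(Timeout.millis(90_000));

        @Test
        public void completesWithinBudget() throws Exception {
            Thread.sleep(100); // fails only if the 90s timeout elapses
        }
    }
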
diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml
index 90c88d4c3d86..5608c066d29e 100644
--- a/deeplearning4j/deeplearning4j-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-core/pom.xml
@@ -22,6 +22,7 @@
        <artifactId>deeplearning4j-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
+
@@ -166,6 +167,46 @@
+    <build>
+        <extensions>
+            <extension>
+                <groupId>org.apache.maven.wagon</groupId>
+                <artifactId>wagon-http</artifactId>
+                <version>2.9</version>
+            </extension>
+            <extension>
+                <groupId>org.kuali.maven.wagons</groupId>
+                <artifactId>maven-s3-wagon</artifactId>
+                <version>1.2.1</version>
+            </extension>
+        </extensions>
+
+        <plugins>
+            <plugin>
+                <artifactId>maven-surefire-plugin</artifactId>
+                <inherited>true</inherited>
+                <configuration>
+                    <useSystemClassLoader>true</useSystemClassLoader>
+                    <useManifestOnlyJar>false</useManifestOnlyJar>
+                    <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
+                    <includes>
+                        <include>*.java</include>
+                        <include>**/*.java</include>
+                    </includes>
+                    <properties>
+                        <property>
+                            <name>listener</name>
+                            <value>org.deeplearning4j.samediff.ToSameDiffTests</value>
+                        </property>
+                    </properties>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+
        <id>test-nd4j-native</id>
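
The surefire `listener` property registers org.deeplearning4j.samediff.ToSameDiffTests as a JUnit run listener for every test in the module. Assuming the class extends JUnit 4's RunListener (its source is not part of this diff), the skeleton looks roughly like:

    import org.junit.runner.Result;
    import org.junit.runner.notification.RunListener;

    // Hypothetical skeleton: the real class also provides the static
    // testToSameDiff(...) helpers invoked from the tests below.
    public class ToSameDiffTests extends RunListener {
        @Override
        public void testRunFinished(Result result) throws Exception {
            // e.g. summarize how many networks were converted to SameDiff successfully
            System.out.println("Run finished: " + result.getRunCount() + " tests");
        }
    }
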
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java
index c004c5c0de1b..1d0147dedc3c 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java
@@ -16,14 +16,42 @@
package org.deeplearning4j;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import java.io.BufferedOutputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import lombok.NonNull;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
+import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
+import org.deeplearning4j.nn.layers.BaseOutputLayer;
+import org.deeplearning4j.nn.layers.LossLayer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer;
import org.deeplearning4j.nn.layers.normalization.BatchNormalization;
@@ -31,7 +59,10 @@
import org.deeplearning4j.nn.layers.recurrent.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
+import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
@@ -40,14 +71,7 @@
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
-
-import java.io.*;
-import java.lang.reflect.Field;
-import java.util.List;
-import java.util.Random;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
public class TestUtils {
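
The new imports (SameDiff, SDVariable, InputType, GraphVertex, ILossFunction, Pair) back helpers that convert a DL4J network to SameDiff and verify the conversion. The helper bodies are not shown in this diff, but the testToSameDiff(net, input, labels) calls added throughout the gradient checks plausibly follow this shape (toSameDiff() is an assumed conversion method introduced by this PR; the exact signature is not visible here):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;

    import static org.junit.Assert.assertEquals;

    public class ToSameDiffSketch {
        // Hypothetical outline: convert, run both versions on the same features,
        // and require identical output activations.
        public static void testToSameDiff(MultiLayerNetwork net, INDArray input, INDArray labels) {
            SameDiff sd = net.toSameDiff(); // assumed conversion API
            INDArray dl4jOutput = net.output(input);
            INDArray sdOutput = sd.batchOutput()
                    .input(sd.inputs().get(0), input)
                    .output(sd.outputs().get(0))
                    .outputSingle();
            assertEquals(dl4jOutput, sdOutput);
        }
    }
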
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
index 709d889017cc..19a0b8f184e2 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
@@ -28,6 +28,8 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.TestToSameDiff;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -106,6 +108,7 @@ public void testSelfAttentionLayer() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
}
}
}
@@ -167,6 +170,7 @@ public void testLearnedSelfAttentionLayer() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
}
}
}
@@ -322,6 +326,8 @@ public void testRecurrentAttentionLayer() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
+
+ ToSameDiffTests.testToSameDiff(net, in, labels);
}
}
}
@@ -385,6 +391,7 @@ public void testAttentionVertex() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in})
.labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null).subset(true).maxPerParam(100));
assertTrue(name, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
}
}
}
@@ -447,6 +454,7 @@ public void testAttentionVertexSameInput() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{in})
.labels(new INDArray[]{labels}).inputMask(inMask != null ? new INDArray[]{inMask} : null));
assertTrue(name, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java
index 081abd45da12..e3e37093a0da 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java
@@ -30,11 +30,11 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
@@ -42,8 +42,6 @@
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
-import org.nd4j.linalg.profiler.OpProfiler;
-import org.nd4j.linalg.profiler.ProfilerConfig;
import java.util.Arrays;
import java.util.HashSet;
@@ -104,6 +102,7 @@ public void testGradient2dSimple() {
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -149,6 +148,7 @@ public void testGradientCnnSimple() {
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -251,6 +251,7 @@ public void testGradientBNWithCNNandSubsampling() {
.labels(labels).excludeParams(excludeParams).subset(true).maxPerParam(25)); //Most params are in output layer, only these should be skipped with this threshold
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -355,6 +356,7 @@ public void testGradientDense() {
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -399,6 +401,7 @@ public void testGradient2dFixedGammaBeta() {
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -444,6 +447,7 @@ public void testGradientCnnFixedGammaBeta() {
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -489,6 +493,7 @@ public void testBatchNormCompGraphSimple() {
assertTrue(gradOK);
TestUtils.testModelSerialization(net);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
}
}
@@ -587,6 +592,7 @@ public void testGradientBNWithCNNandSubsamplingCompGraph() {
assertTrue(gradOK);
TestUtils.testModelSerialization(net);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java
index 06fe0cf350a6..834d208d83e9 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java
@@ -27,8 +27,8 @@
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.deeplearning4j.util.Convolution1DUtils;
-import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -122,6 +122,7 @@ public void testCnn1DWithLocallyConnected1D() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
@@ -202,6 +203,7 @@ public void testCnn1DWithCropping1D() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -285,6 +287,7 @@ public void testCnn1DWithZeroPadding1D() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -362,6 +365,7 @@ public void testCnn1DWithSubsampling1D() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -423,6 +427,7 @@ public void testCnn1dWithMasking(){
.labels(label).inputMask(fm));
assertTrue(s, gradOK);
+ ToSameDiffTests.testToSameDiff(net, f, label);
TestUtils.testModelSerialization(net);
//TODO also check that masked step values don't impact forward pass, score or gradients
@@ -518,6 +523,7 @@ public void testCnn1Causal() {
.labels(label).inputMask(fm));
assertTrue(s, gradOK);
+ ToSameDiffTests.testToSameDiff(net, f, label);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
index 30cc783da458..f51eff628105 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java
@@ -29,9 +29,9 @@
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -159,6 +159,7 @@ public void testCnn3DPlain() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -262,6 +263,7 @@ public void testCnn3DZeroPadding() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
@@ -352,6 +354,7 @@ public void testCnn3DPooling() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -442,6 +445,7 @@ public void testCnn3DUpsampling() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -541,6 +545,7 @@ public void testCnn3DCropping() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
@@ -632,6 +637,7 @@ public void testDeconv3d() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java
index c303cc594498..1b617a2f7ab2 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java
@@ -16,7 +16,6 @@
package org.deeplearning4j.gradientcheck;
-import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
@@ -33,6 +32,7 @@
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -153,6 +153,7 @@ public void testGradientCNNMLN() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -247,6 +248,8 @@ public void testGradientCNNL1L2MLN() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
+
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -309,6 +312,7 @@ public void testCnnWithSpaceToDepth() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -378,6 +382,7 @@ public void testCnnWithSpaceToBatch() {
.labels(new INDArray[]{labels}));
assertTrue(msg + " - compgraph", gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -438,6 +443,7 @@ public void testCnnWithUpsampling() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -509,6 +515,7 @@ public void testCnnWithSubsampling() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -578,6 +585,7 @@ public void testCnnWithSubsamplingV2() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -638,6 +646,8 @@ public void testCnnLocallyConnected2D() {
assertTrue(msg, gradOK);
+ //TODO existing define method requires offline shape inference
+ // ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -705,6 +715,7 @@ public void testCnnMultiLayer() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -770,6 +781,7 @@ public void testCnnSamePaddingMode() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -837,6 +849,7 @@ public void testCnnSamePaddingModeStrided() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -920,6 +933,7 @@ public void testCnnZeroPaddingLayer() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -995,6 +1009,7 @@ public void testDeconvolution2D() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -1068,6 +1083,7 @@ public void testSeparableConv2D() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -1152,6 +1168,7 @@ public void testCnnDilated() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -1227,6 +1244,7 @@ public void testCropping2DLayer() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -1298,6 +1316,7 @@ public void testDepthwiseConv2D() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java
index e604c594dc1a..0ff5ce3e38ac 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java
@@ -31,6 +31,7 @@
import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
@@ -39,8 +40,6 @@
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
-import java.util.Random;
-
public class CapsnetGradientCheckTest extends BaseDL4JTest {
@Override
@@ -114,6 +113,7 @@ public void testCapsNet() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java
index c4f9d2843af0..316a79912dfc 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java
@@ -31,6 +31,7 @@
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -141,6 +142,7 @@ public void testDropoutGradient() {
false, -1, null, 12345); //Last arg: ensures RNG is reset at each iter... otherwise will fail due to randomness!
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, f, l);
TestUtils.testModelSerialization(mln);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java
index 742052a42b54..e2bea569a386 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java
@@ -27,6 +27,7 @@
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -106,6 +107,7 @@ public void testRNNGlobalPoolingBasicMultiLayer() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -165,6 +167,7 @@ public void testCnnGlobalPoolingBasicMultiLayer() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -225,6 +228,7 @@ public void testLSTMWithMasking() {
.labels(labels).inputMask(featuresMask));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -308,6 +312,7 @@ public void testCnnGlobalPoolingMasking() {
.labels(labels).inputMask(inputMask));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java
index 2c6f8843e375..a8022d0753a4 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java
@@ -27,11 +27,13 @@
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -136,6 +138,7 @@ public void testMinibatchApplication() {
String msg = "testMinibatchApplication() - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst;
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, ds.getFeatures(), ds.getLabels());
TestUtils.testModelSerialization(mln);
}
@@ -216,6 +219,7 @@ public void testGradientMLP2LayerIrisSimple() {
String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst;
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -311,6 +315,8 @@ public void testGradientMLP2LayerIrisL1L2Simple() {
+ ", outputActivation=" + outputActivation + ", doLearningFirst="
+ doLearningFirst + ", l2=" + l2 + ", l1=" + l1;
assertTrue(msg, gradOK);
+
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -394,6 +400,7 @@ public void testEmbeddingLayerSimple() {
String msg = "testEmbeddingLayerSimple";
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
@@ -446,6 +453,7 @@ public void testAutoEncoder() {
.activation(afn).build())
.layer(1, new OutputLayer.Builder(lf).nIn(3).nOut(3)
.activation(outputActivation).build())
+ .setInputType(InputType.inferInputType(input))
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -482,6 +490,7 @@ public void testAutoEncoder() {
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -553,6 +562,7 @@ public void elementWiseMultiplicationLayerTest(){
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(netGraph);
+ ToSameDiffTests.testToSameDiff(netGraph, features, labels);
}
}
@@ -579,6 +589,7 @@ public void testEmbeddingSequenceLayer(){
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.dataFormat(seqOutputFormat)
.lossFunction(LossFunction.MSE).build())
+ .setInputType(InputType.recurrent(3, 6, RNNFormat.NCW))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -606,6 +617,7 @@ public void testEmbeddingSequenceLayer(){
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, label);
TestUtils.testModelSerialization(net);
@@ -705,6 +717,7 @@ public void testGradientWeightDecay() {
+ ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1;
assertTrue(msg, gradOK1);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -789,6 +802,7 @@ public void testGradientMLP2LayerIrisLayerNorm() {
String msg = "testGradMLP2LayerIrisSimple() - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", layerNorm=" + layerNorm;
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
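
Several configurations above gain an explicit setInputType(...) call, which lets DL4J (and the SameDiff conversion) resolve each layer's input shape; InputType.inferInputType(INDArray) derives the type directly from an example array. For instance:

    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class InferInputTypeExample {
        public static void main(String[] args) {
            INDArray ff = Nd4j.rand(new int[]{8, 4});       // [minibatch, features]
            INDArray rnn = Nd4j.rand(new int[]{8, 4, 10});  // [minibatch, features, timeSteps]
            System.out.println(InputType.inferInputType(ff));   // feed-forward with 4 features
            System.out.println(InputType.inferInputType(rnn));  // recurrent, 4 features, 10 steps
        }
    }
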
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
index ac3c3deea8ef..1c5a46431988 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java
@@ -32,12 +32,10 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
-import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -48,7 +46,6 @@
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
-import java.util.Arrays;
import java.util.Map;
import java.util.Random;
@@ -116,6 +113,7 @@ public void testBasicIris() {
String msg = "testBasicIris()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
@Test
@@ -167,6 +165,7 @@ public void testBasicIrisWithMerging() {
String msg = "testBasicIrisWithMerging()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
@Test
@@ -224,6 +223,7 @@ public void testBasicIrisWithElementWiseNode() {
String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
}
@@ -284,6 +284,7 @@ public void testBasicIrisWithElementWiseNodeInputSizeGreaterThanTwo() {
String msg = "testBasicIrisWithElementWiseVertex(op=" + op + ")";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
}
@@ -331,6 +332,7 @@ public void testElementWiseVertexBroadcast(){
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, in, labels);
}
}
}
@@ -383,6 +385,7 @@ public void testCnnDepthMerge() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{input}, new INDArray[]{labels}, new InputType[]{InputType.convolutional(6, 6, 2, format)});
}
}
@@ -443,6 +446,7 @@ public void testRNNWithMerging() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
}
@@ -480,6 +484,7 @@ public void testLSTMWithSubset() {
String msg = "testLSTMWithSubset()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
@Test
@@ -528,6 +533,7 @@ public void testLSTMWithLastTimeStepVertex() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
@Test
@@ -577,6 +583,7 @@ public void testLSTMWithDuplicateToTimeSeries() {
String msg = "testLSTMWithDuplicateToTimeSeries()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{input1, input2}, new INDArray[]{labels});
}
@Test
@@ -636,6 +643,7 @@ public void testLSTMWithReverseTimeSeriesVertex() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, labels);
}
@Test
@@ -679,6 +687,7 @@ public void testMultipleInputsLayer() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, inputs, new INDArray[]{out});
}
}
@@ -719,6 +728,7 @@ public void testMultipleOutputsLayer() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, out);
}
}
@@ -765,6 +775,7 @@ public void testMultipleOutputsMergeVertex() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, new INDArray[]{out});
}
}
@@ -816,6 +827,7 @@ public void testMultipleOutputsMergeCnn() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, input, out);
}
}
@@ -885,6 +897,7 @@ public void testBasicIrisTripletStackingL2Loss() {
String msg = "testBasicIrisTripletStackingL2Loss()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{pos, anc, neg}, new INDArray[]{labels});
}
@@ -945,6 +958,7 @@ public void testBasicCenterLoss() {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, example, labels);
}
}
}
@@ -1009,6 +1023,7 @@ public void testCnnPoolCenterLoss() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, example, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, example, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -1059,6 +1074,7 @@ public void testBasicL2() {
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels});
}
}
@@ -1117,6 +1133,7 @@ public void testBasicStackUnstack() {
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2});
}
}
@@ -1175,6 +1192,7 @@ public void testBasicStackUnstackDebug() {
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2});
}
}
@@ -1240,6 +1258,7 @@ public void testBasicStackUnstackVariableLengthTS() {
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2});
}
}
@@ -1296,6 +1315,7 @@ public void testBasicTwoOutputs() {
.labels(new INDArray[]{labels1, labels2}));
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, new INDArray[]{in1, in2}, new INDArray[]{labels1, labels2});
}
}
@@ -1339,6 +1359,7 @@ public void testL2NormalizeVertex2d() {
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, in1, labels1);
}
}
@@ -1388,6 +1409,7 @@ public void testL2NormalizeVertex4d() {
assertTrue(testName, gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, in1, labels1);
}
}
@@ -1427,5 +1449,6 @@ public void testGraphEmbeddingLayerSimple() {
String msg = "testGraphEmbeddingLayerSimple";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(cg);
+ ToSameDiffTests.testToSameDiff(cg, input, labels);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java
index 09afd6c2f4c7..613b5ab5944b 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java
@@ -28,6 +28,7 @@
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -137,6 +138,7 @@ public void gradientCheckMaskingOutputSimple() {
String msg = "gradientCheckMaskingOutputSimple() - timeSeriesLength=" + timeSeriesLength
+ ", miniBatchSize=" + 1;
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -186,6 +188,7 @@ public void testBidirectionalLSTMMasking() {
.labels(labels).inputMask(mask).labelMask(mask).subset(true).maxPerParam(12));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -267,6 +270,7 @@ public void testPerOutputMaskingMLP() {
.labels(labels).labelMask(labelMask));
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, features, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -384,6 +388,7 @@ public void testPerOutputMaskingRnn() {
assertTrue(msg + " (compgraph)", gradOK);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, features, labels);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java
index 5a6e003f258a..9ecc62f9e88b 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java
@@ -26,6 +26,7 @@
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -96,6 +97,7 @@ public void testGradientLRNSimple() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java
index 2f0822b80b16..80b4bc241255 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java
@@ -27,6 +27,7 @@
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -137,6 +138,7 @@ public void testLSTMBasicMultiLayer() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(testName, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -226,6 +228,7 @@ public void testGradientLSTMFull() {
.labels(labels).subset(true).maxPerParam(128));
assertTrue(testName, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -276,6 +279,7 @@ public void testGradientLSTMEdgeCases() {
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -356,6 +360,7 @@ public void testGradientGravesBidirectionalLSTMFull() {
String msg = "testGradientGravesLSTMFull() - activationFn=" + afn + ", lossFn=" + lf
+ ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1;
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -405,6 +410,7 @@ public void testGradientGravesBidirectionalLSTMEdgeCases() {
String msg = "testGradientGravesLSTMEdgeCases() - timeSeriesLength=" + timeSeriesLength[i]
+ ", miniBatchSize=" + miniBatchSize[i];
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -460,6 +466,7 @@ public void testGradientCnnFfRnn() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input)
.labels(labels).subset(true).maxPerParam(32));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
index 88ea98ca23ea..e67969c57fa8 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
@@ -30,6 +30,7 @@
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
@@ -226,6 +227,7 @@ public void lossFunctionGradientCheck() {
} else {
failed.add(testName);
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -395,6 +397,7 @@ public void lossFunctionGradientCheckLossLayer() {
}
TestUtils.testModelSerialization(net);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
}
}
@@ -703,6 +706,8 @@ public void lossFunctionWeightedGradientCheck() {
} else {
failed.add(testName);
}
+
+ ToSameDiffTests.testToSameDiff(net, input, labels);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java
index cc4e7410413d..dcdf58d53395 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java
@@ -24,6 +24,7 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -121,6 +122,7 @@ public void testGradientNoBiasDenseOutput() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -178,6 +180,7 @@ public void testGradientNoBiasRnnOutput() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -240,6 +243,7 @@ public void testGradientNoBiasEmbedding() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -306,6 +310,7 @@ public void testCnnWithSubsamplingNoBias() {
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
index 24745fd0a1e1..222f3d497491 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
@@ -22,8 +22,10 @@
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -146,6 +148,7 @@ public void testRnnLossLayer() {
.labels(labels).labelMask(labelMask));
assertTrue(testName, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -231,11 +234,13 @@ public void testCnnLossLayer() {
.convolutionMode(ConvolutionMode.Same)
.list()
.layer(new ConvolutionLayer.Builder().nIn(dIn).nOut(dOut).activation(Activation.TANH)
+ .kernelSize(3, 3)
.dist(new NormalDistribution(0, 1.0))
.updater(new NoOp()).build())
.layer(new CnnLossLayer.Builder(lf)
.activation(oa)
.build())
+ .setInputType(InputType.inferInputType(input))
.validateOutputLayerConfig(false).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -253,6 +258,7 @@ public void testCnnLossLayer() {
.labels(labels).labelMask(labelMask));
assertTrue(testName, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -385,6 +391,7 @@ public void testCnn3dLossLayer() {
.lossFunction(lf)
.activation(oa)
.build())
+ .setInputType(InputType.inferInputType(input))
.validateOutputLayerConfig(false).build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -402,6 +409,8 @@ public void testCnn3dLossLayer() {
.labels(labels).labelMask(labelMask));
assertTrue(testName, gradOK);
+ //TODO known loss issue due to DL4J packing dimensions into batch
+ ToSameDiffTests.testToSameDiff(mln, input, null);
TestUtils.testModelSerialization(mln);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java
index e356cce1d458..0d747e61ce15 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java
@@ -31,6 +31,7 @@
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -133,6 +134,7 @@ public void testBidirectionalWrapper() {
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -211,6 +213,7 @@ public void testSimpleRnn() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -286,6 +289,7 @@ public void testLastTimeStepLayer(){
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(16));
assertTrue(name, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -350,6 +354,7 @@ public void testTimeDistributedDense() {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(labels).inputMask(inMask).subset(true).maxPerParam(16));
assertTrue(name, gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java
index b8412a8d26a5..ac3bd773f317 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java
@@ -31,6 +31,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -186,6 +187,7 @@ public void testMaskLayer() {
.input(input).labels(label).inputMask(inMask));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, label);
TestUtils.testModelSerialization(net);
}
}
@@ -226,6 +228,7 @@ public void testFrozenWithBackprop(){
.labels(labels).excludeParams(excludeParams));
assertTrue(gradOK);
+ ToSameDiffTests.testToSameDiff(net, in, labels);
TestUtils.testModelSerialization(net);
@@ -238,6 +241,7 @@ public void testFrozenWithBackprop(){
assertTrue(gradOKCG);
TestUtils.testModelSerialization(g);
+ ToSameDiffTests.testToSameDiff(g, in, labels);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java
index 3d4a6180c5bc..cf71fb831ffb 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java
@@ -21,10 +21,12 @@
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationTanH;
@@ -116,6 +118,7 @@ public void testVaeAsMLP() {
.dist(new NormalDistribution(0, 1))
.build())
+ .setInputType(InputType.inferInputType(input))
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -135,6 +138,7 @@ public void testVaeAsMLP() {
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input,
labels);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -184,6 +188,7 @@ public void testVaePretrain() {
.reconstructionDistribution(
new GaussianReconstructionDistribution(pxzAfn))
.activation(afn).build())
+ .setInputType(InputType.inferInputType(input))
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -207,6 +212,7 @@ public void testVaePretrain() {
RETURN_ON_FIRST_FAILURE, input, 12345);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, input, labels);
TestUtils.testModelSerialization(mln);
}
}
@@ -275,6 +281,7 @@ public void testVaePretrainReconstructionDistributions() {
reconstructionDistributions[i])
.activation(Activation.TANH)
.build())
+ .setInputType(InputType.inferInputType(data))
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -295,6 +302,7 @@ public void testVaePretrainReconstructionDistributions() {
data, 12345);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, data, null);
TestUtils.testModelSerialization(mln);
}
}
@@ -317,6 +325,7 @@ public void testVaePretrainMultipleSamples() {
new GaussianReconstructionDistribution(Activation.TANH))
.numSamples(numSamples).activation(Activation.TANH)
.build())
+ .setInputType(InputType.inferInputType(features))
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -337,6 +346,7 @@ public void testVaePretrainMultipleSamples() {
features, 12345);
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(mln, features, null);
TestUtils.testModelSerialization(mln);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java
index 47c040c1214e..9d4704b3aac9 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java
@@ -31,6 +31,7 @@
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@@ -153,6 +154,7 @@ public void testYoloOutputLayer() {
.labels(labels).subset(true).maxPerParam(100));
assertTrue(msg, gradOK);
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -261,6 +263,7 @@ public void yoloGradientCheckRealData() throws Exception {
.labels(l).inputMask(null).subset(true).maxPerParam(64));
assertTrue(ok);
+ ToSameDiffTests.testToSameDiff(net, f, l);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java
index dbef14bf27f6..7c3a901b60b4 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java
@@ -24,7 +24,7 @@
public class SDLossMAE extends SameDiffLoss {
@Override
- public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
- return sd.math.abs(labels.sub(layerInput)).mean(1);
+ public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) {
+ return sameDiff.math.abs(labels.sub(layerInput)).mean(1);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java
index 6edce7a499c8..5eb6b91bec4f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java
@@ -24,7 +24,7 @@
public class SDLossMSE extends SameDiffLoss {
@Override
- public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
+ public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) {
return labels.squaredDifference(layerInput).mean(1);
}
}
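
Both custom losses move from defineLoss to defineLossArray: the method now returns the per-example loss (one value per minibatch row), and the SameDiffLoss base class presumably handles the reduction to a scalar score. A third example in the same style, a log-cosh loss (class name and the SameDiffLoss import path are illustrative assumptions):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.lossfunctions.SameDiffLoss; // assumed package

    public class SDLossLogCosh extends SameDiffLoss {
        @Override
        public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) {
            SDVariable diff = labels.sub(layerInput);
            // log(cosh(x)) behaves like x^2/2 near zero and |x| - log(2) for large |x|;
            // the mean over dimension 1 yields one loss value per example
            return sameDiff.math.log(sameDiff.math.cosh(diff)).mean(1);
        }
    }
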
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java
index 0db6a13572f3..e981bb115ac0 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java
@@ -37,6 +37,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -100,6 +101,7 @@ public void testLayerRecurrentConstraints() throws Exception {
assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -153,6 +155,7 @@ public void testLayerBiasConstraints() throws Exception {
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -205,6 +208,7 @@ public void testLayerWeightsConstraints() throws Exception {
assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -265,6 +269,7 @@ public void testLayerWeightsAndBiasConstraints() throws Exception {
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -326,6 +331,7 @@ public void testLayerWeightsAndBiasSeparateConstraints() throws Exception {
assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6);
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
@@ -384,6 +390,7 @@ public void testModelConstraints() throws Exception {
assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 );
}
+ ToSameDiffTests.testToSameDiff(net, input, labels);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java
index 574cb0d39c23..9f2a2cd1de63 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java
@@ -30,6 +30,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
@@ -144,7 +145,7 @@ public void testCalls(){
}
@Data
- public static class CustomDropout implements IDropout{
+ public static class CustomDropout extends BaseDropout{
private List<Pair<Integer, Integer>> allCalls = new ArrayList<>();
private List<Pair<Integer, Integer>> allReverseCalls = new ArrayList<>();
@@ -191,6 +192,7 @@ public void testSerialization(){
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
+ ToSameDiffTests.testToSameDiff(net, null, null);
TestUtils.testModelSerialization(net);
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java
index 77061e3987ee..208e88295a99 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/custom/MyCustomPreprocessor.java
@@ -18,8 +18,8 @@
import lombok.EqualsAndHashCode;
import org.deeplearning4j.nn.api.MaskState;
-import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -28,7 +28,7 @@
* Created by Alex on 09/09/2016.
*/
@EqualsAndHashCode
-public class MyCustomPreprocessor implements InputPreProcessor {
+public class MyCustomPreprocessor extends BaseInputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@@ -41,7 +41,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w
}
@Override
- public InputPreProcessor clone() {
+ public BaseInputPreProcessor clone() {
return new MyCustomPreprocessor();
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java
index 0ebc598bc68b..444373ce3d94 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java
@@ -32,6 +32,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -77,6 +78,7 @@ public void testWeightNoiseConfigJson() {
assertEquals(wn, ((BaseLayer) net.getLayer(2).conf().getLayer()).getWeightNoise());
TestUtils.testModelSerialization(net);
+ ToSameDiffTests.testToSameDiff(net, null, null);
ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder()
@@ -97,6 +99,7 @@ public void testWeightNoiseConfigJson() {
assertEquals(wn, ((BaseLayer) graph.getLayer(2).conf().getLayer()).getWeightNoise());
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, Nd4j.create(1,10), Nd4j.create(1,10));
graph.fit(new DataSet(Nd4j.create(1,10), Nd4j.create(1,10)));
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
index beec5cf2042c..ecc731a85221 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
@@ -20,6 +20,7 @@
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@@ -1024,6 +1025,8 @@ public void testEmbeddingDtypes() {
logUsedClasses(net);
+ ToSameDiffTests.testToSameDiff(net, input, label);
+
//Now, test mismatched dtypes for input/labels:
for (DataType inputLabelDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
INDArray in2 = input.castTo(inputLabelDtype);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
index b0cc17376248..b20517e316c9 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
@@ -56,6 +56,7 @@
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.*;
import org.junit.rules.TemporaryFolder;
@@ -1471,6 +1472,8 @@ public void testZeroParamNet() throws Exception {
ComputationGraph net2 = TestUtils.testModelSerialization(net);
INDArray out2 = net2.outputSingle(ds.getFeatures());
assertEquals(out, out2);
+ // Labels are the wrong size; they would need to be [batch, 1] for LossLayer. Convolutional input is handled via a preprocessor here, not CnnLossLayer.
+ ToSameDiffTests.testToSameDiff(net, ds.getFeatures(), null);
}
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java
index 23c4421e57fe..dff0023ab2e4 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java
@@ -32,6 +32,9 @@
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.TestToSameDiff;
+import org.deeplearning4j.samediff.ToSameDiffTests;
+import org.deeplearning4j.util.ToSameDiffUtils;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -144,6 +147,8 @@ public void testDropoutLayerWithoutTraining() throws Exception {
assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(4));
assertEquals(actTestIntegrated.get(1), actTestSeparate.get(2));
assertEquals(actTestIntegrated.get(2), actTestSeparate.get(4));
+
+ ToSameDiffTests.testToSameDiff(netIntegrated, in, null);
}
@Test
@@ -297,5 +302,7 @@ public void testDropoutLayerWithConvMnist() throws Exception {
List<INDArray> actTestSeparate = netSeparate.feedForward(false);
assertEquals(actTestIntegrated.get(1), actTestSeparate.get(1));
assertEquals(actTestIntegrated.get(2), actTestSeparate.get(3));
+ ToSameDiffTests.testToSameDiff(netIntegrated, next.getFeatures().dup(), next.getLabels().dup());
+ ToSameDiffTests.testToSameDiff(netSeparate, next.getFeatures().dup(), next.getLabels().dup());
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java
index 65d204964fc3..669bb7484cf6 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java
@@ -27,6 +27,7 @@
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -372,5 +373,7 @@ public void testFrozenLayerInstantiationCompGraph() {
INDArray out3 = net3.outputSingle(input);
assertEquals(out2, out3);
+
+ ToSameDiffTests.testToSameDiff(net2, input, null);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java
index 200f55071d24..54a9e59220c0 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java
@@ -30,6 +30,7 @@
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -98,6 +99,8 @@ public void testFrozenWithBackpropLayerInstantiation() {
INDArray out3 = net3.output(input);
assertEquals(out2, out3);
+
+ ToSameDiffTests.testToSameDiff(net1, input, null);
}
@Test
@@ -153,6 +156,7 @@ public void testFrozenLayerInstantiationCompGraph() {
INDArray out3 = net3.outputSingle(input);
assertEquals(out2, out3);
+ ToSameDiffTests.testToSameDiff(net2, input, null);
}
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java
index 2d746a1372ec..acf736a50651 100755
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java
@@ -22,6 +22,7 @@
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
@@ -32,6 +33,7 @@
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -332,6 +334,7 @@ public void testCompareRnnOutputRnnLoss(){
assertEquals(mln.gradient().gradient(), mln2.gradient().gradient());
assertEquals(mln.score(), mln2.score(), 1e-6);
+ ToSameDiffTests.testToSameDiff(mln, in, labels);
TestUtils.testModelSerialization(mln);
}
@@ -361,6 +364,7 @@ public void testCnnLossLayer(){
.layer(new CnnLossLayer.Builder(LossFunction.MSE)
.activation(a)
.build())
+ .setInputType(InputType.convolutional(5, 5, 4))
.build();
MultiLayerConfiguration conf2 =
@@ -421,6 +425,7 @@ public void testCnnLossLayer(){
assertArrayEquals(new long[]{2, 1}, s.shape());
assertEquals(s.getDouble(0), s.getDouble(1), 1e-6);
+ ToSameDiffTests.testToSameDiff(mln, in2, labels2);
TestUtils.testModelSerialization(mln);
}
}
@@ -515,6 +520,8 @@ public void testCnnLossLayerCompGraph(){
assertEquals(s.getDouble(0), s.getDouble(1), 1e-6);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, in, labels);
+ ToSameDiffTests.testToSameDiff(graph2, in2, labels2);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java
index 76d14d47d46f..f5dea74d58a1 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java
@@ -27,16 +27,20 @@
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -924,6 +928,14 @@ public static void testHelper(TestCase tc) {
assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2));
}
+ //TODO LocallyConnected2D throws NPEs because SDVariable shapes are not available
+ if(!(tc.net1.getnLayers() > 1 && tc.net1.getLayer(1).getConfig() instanceof LocallyConnected2D)) {
+ ToSameDiffTests.testToSameDiff(tc.net1, inNCHW, null);
+ ToSameDiffTests.testToSameDiff(tc.net2, inNCHW, null);
+ ToSameDiffTests.testToSameDiff(tc.net3, inNHWC, null);
+ ToSameDiffTests.testToSameDiff(tc.net4, inNHWC, null);
+ }
+
}
private static List<String> differentGrads(Gradient g1, Gradient g2){
@@ -943,7 +955,7 @@ private static List<String> differentGrads(Gradient g1, Gradient g2){
//Converts NHWC to NCHW activations
@EqualsAndHashCode
- private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
+ private static class NHWCToNCHWPreprocessor extends BaseInputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@@ -956,7 +968,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w
}
@Override
- public InputPreProcessor clone() {
+ public BaseInputPreProcessor clone() {
return this;
}
@@ -970,6 +982,11 @@ public InputType getOutputType(InputType inputType) {
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
+
+ @Override
+ public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
+ return input.permute(0, 3, 1, 2);
+ }
}
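For reference, a minimal sketch of the array-level effect of the definePreProcess override above (illustrative class name; assumes the standard NHWC and NCHW activation layouts):

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NhwcToNchwSketch {
    public static void main(String[] args) {
        // NHWC activations: [batch, height, width, channels]
        INDArray nhwc = Nd4j.rand(new int[]{2, 5, 5, 3});

        // permute(0, 3, 1, 2) moves channels in front of the spatial dims: NCHW
        INDArray nchw = nhwc.permute(0, 3, 1, 2);

        System.out.println(Arrays.toString(nchw.shape())); // [2, 3, 5, 5]
    }
}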
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java
index 1d354ef519de..3595ec327abf 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java
@@ -72,6 +72,7 @@ public void testConvolutionLayerSetup() {
builder.setInputType(InputType.convolutionalFlat(28, 28, 1));
MultiLayerConfiguration completed = complete().build();
MultiLayerConfiguration test = builder.build();
+ test.setInputType(null);
assertEquals(completed, test);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java
index 431831487b94..43586ef4d7d3 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java
@@ -34,6 +34,7 @@
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -693,6 +694,7 @@ public void test1dInputType(){
INDArray in = Nd4j.create(2, 10, 6);
INDArray out = net.output(in);
assertArrayEquals(new long[]{2,7,6}, out.shape());
+ ToSameDiffTests.testToSameDiff(net, in, null);
}
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java
index 01d94e6dcb08..d13348368e90 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java
@@ -26,7 +26,7 @@
* Created by Alex on 19/12/2016.
*/
@EqualsAndHashCode
-public class CustomActivation extends BaseActivationFunction implements IActivation {
+public class CustomActivation extends BaseActivationFunction {
@Override
public INDArray getActivation(INDArray in, boolean training) {
return in;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java
index ff79adf0a3de..d74708bd5627 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java
@@ -20,6 +20,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -28,6 +29,8 @@
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
index 96ab25267799..18ddf29b47e5 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
@@ -33,6 +33,7 @@
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
@@ -556,6 +557,7 @@ public void testW2VInits(){
INDArray w = net.getParam("0_W");
assertEquals(vectors, w);
+ ToSameDiffTests.testToSameDiff(net, null, null);
TestUtils.testModelSerialization(net);
//Test same thing for embedding sequence layer:
@@ -573,6 +575,7 @@ public void testW2VInits(){
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
+ .setInputType(InputType.feedForward(10))
.build();
net = new MultiLayerNetwork(conf);
@@ -581,6 +584,7 @@ public void testW2VInits(){
w = net.getParam("0_W");
assertEquals(vectors, w);
+ ToSameDiffTests.testToSameDiff(net, null, null);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java
index e2c38bfad457..d5a76b051728 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java
@@ -39,6 +39,7 @@
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -451,6 +452,7 @@ public void checkSerialization() throws Exception {
assertEquals(out, out2);
+ ToSameDiffTests.testToSameDiff(net, in, null);
MultiLayerNetwork net2 = TestUtils.testModelSerialization(net);
INDArray outDeser = net2.output(in, false);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java
index 699d6bf552c8..26badf332424 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java
@@ -21,6 +21,7 @@
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.nn.conf.GradientNormalization;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
@@ -93,6 +94,7 @@ public void testYoloActivateScoreBasic() {
.layer(new Yolo2OutputLayer.Builder()
.boundingBoxPriors(bbPrior)
.build())
+ .setInputType(InputType.convolutional(h, w, depth))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
@@ -159,6 +161,8 @@ public void testYoloActivateScoreBasic() {
assertArrayEquals(new long[]{mb,1}, scoreArr1.shape());
assertArrayEquals(new long[]{mb,1}, scoreArr2.shape());
assertNotEquals(scoreArr1, scoreArr2);
+
+ ToSameDiffTests.testToSameDiff(net, input, labels);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java
index 2fef54844c5b..983ce9e94efb 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java
@@ -39,6 +39,8 @@
import org.deeplearning4j.nn.updater.MultiLayerUpdater;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.TestToSameDiff;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.junit.Test;
@@ -176,6 +178,7 @@ public void compareImplementations(){
INDArray p1 = net1.params();
INDArray p2 = net2.params();
assertEquals(p1, p2);
+ ToSameDiffTests.testToSameDiff(net1, InputType.inferInputType(in), in, labels);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java
index 7ddc31220987..70d846e85125 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java
@@ -22,10 +22,12 @@
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
+import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -111,10 +113,12 @@ public void testSerialization(){
.list()
.layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder()
.setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build())
+ .setInputType(InputType.recurrent(4, 10))
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
+ ToSameDiffTests.testToSameDiff(net, null, null);
TestUtils.testModelSerialization(net);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java
index 93566050f39d..e1c2727a3da7 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java
@@ -35,6 +35,7 @@
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -374,6 +375,11 @@ public static void testHelper(TestCase tc) {
assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW
assertEquals(tc.msg, out1, net4a.output(inNWC));
}
+
+ ToSameDiffTests.testToSameDiff(tc.net1, inNCW, null);
+ ToSameDiffTests.testToSameDiff(tc.net2, inNCW, null);
+ ToSameDiffTests.testToSameDiff(tc.net3, inNWC, null);
+ ToSameDiffTests.testToSameDiff(tc.net4, inNWC, null);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java
index 9f60d674dd00..ba902a7d8203 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java
@@ -29,6 +29,7 @@
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -119,6 +120,7 @@ public void testLastTimeStepVertex() {
assertEquals(expOut, outFwd);
TestUtils.testModelSerialization(graph);
+ ToSameDiffTests.testToSameDiff(graph, in, null);
}
@Test
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java
index b2b789d589b0..ba546b7b3ee5 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java
@@ -244,6 +244,8 @@ public void testMismatchedInputLabelLength(){
if(msg == null)
t.printStackTrace();
System.out.println(i);
+
+ //TODO Add checks & a clear error message in RnnOutputLayer etc., in the loss calculation before the reshape.
assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java
index 639d3fafdc1c..8be0b29b1e49 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java
@@ -24,6 +24,7 @@
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -117,6 +118,7 @@ public void testSimpleRnn(){
}
+ ToSameDiffTests.testToSameDiff(net, null, null);
TestUtils.testModelSerialization(net);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java
index 0d38d699cea5..47b33d93ddcf 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java
@@ -16,6 +16,7 @@
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -104,6 +105,7 @@ public void testTimeDistributed(){
MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2);
out2 = net2.output(in);
INDArray out3 = net3.output(in);
+ ToSameDiffTests.testToSameDiff(net3, in, labels);
assertEquals(out2, out3);
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java
index 243909bd9b70..78003b6af860 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java
@@ -122,7 +122,8 @@ public void validateInput(INDArray input) {
}
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable) {
return layerInput;
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java
index 4e923bf4aa63..58724c7782db 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java
@@ -28,6 +28,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffDenseVertex;
import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -172,6 +173,7 @@ public void testSameDiffDenseVertex() {
INDArray outMbsd = netSD.output(newIn)[0];
INDArray outMb = netStandard.output(newIn)[0];
assertEquals(outMb, outMbsd);
+ ToSameDiffTests.testToSameDiff(netSD, in, l);
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java
index 8368d3869f26..d8f0dbdf18f9 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java
@@ -30,6 +30,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaLayer;
import org.deeplearning4j.nn.layers.samediff.testlayers.SameDiffSimpleLambdaVertex;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
@@ -130,6 +131,7 @@ public void testSameDiffLamdaLayerBasic(){
INDArray outMbsd = lambda.output(newIn)[0];
INDArray outMb = std.output(newIn)[0];
assertEquals(outMb, outMbsd);
+ ToSameDiffTests.testToSameDiff(lambda, in, labels);
}
}
@@ -216,6 +218,7 @@ public void testSameDiffLamdaVertexBasic(){
INDArray outMbsd = lambda.output(newIn1, newIn2)[0];
INDArray outMb = std.output(newIn1, newIn2)[0];
assertEquals(outMb, outMbsd);
+ ToSameDiffTests.testToSameDiff(lambda, new INDArray[]{in1, in2}, new INDArray[]{labels});
}
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java
index 9cbbccaa741b..bf16fb314407 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java
@@ -48,7 +48,8 @@ protected MinimalSameDiffDense(){
}
@Override
- public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable) {
SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java
index 7b78c14fccad..f2a1952b6b0d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java
@@ -126,7 +126,8 @@ public void initializeParameters(Map<String, INDArray> params) {
}
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable) {
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java
index 630b6059c169..f971238f745f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java
@@ -102,7 +102,8 @@ public void initializeParameters(Map<String, INDArray> params){
}
@Override
- public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sd, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable) {
SDVariable weights = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java
index 9ada8a1dc8d6..45af1cc12192 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSELossLayer.java
@@ -27,7 +27,7 @@
public class SameDiffMSELossLayer extends SameDiffOutputLayer {
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map<String, SDVariable> paramTable) {
+ public SDVariable defineLayerAndLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map<String, SDVariable> paramTable) {
//MSE: 1/nOut * (input-labels)^2
SDVariable diff = layerInput.sub(labels);
return diff.mul(diff).mean(1).sum(0);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java
index cfead640f0ff..f2ff1ff50d34 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java
@@ -44,7 +44,7 @@ public SameDiffMSEOutputLayer(int nIn, int nOut, Activation activation, WeightIn
}
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map<String, SDVariable> paramTable) {
+ public SDVariable defineLayerAndLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels, Map<String, SDVariable> paramTable) {
SDVariable z = sameDiff.mmul(layerInput, paramTable.get("W")).add(paramTable.get("b"));
SDVariable out = activation.asSameDiff("out", sameDiff, z);
//MSE: 1/nOut * (input-labels)^2
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java
index 4282c5e62bbe..deea4be60125 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java
@@ -24,6 +24,7 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
+import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.misc.iter.WSTestDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -208,7 +209,7 @@ public void testWithPreprocessorsMLN() {
}
}
- public static class DupPreProcessor implements InputPreProcessor {
+ public static class DupPreProcessor extends BaseInputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr mgr) {
return mgr.dup(ArrayType.ACTIVATIONS, input);
@@ -220,7 +221,7 @@ public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr w
}
@Override
- public InputPreProcessor clone() {
+ public BaseInputPreProcessor clone() {
return new DupPreProcessor();
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java
index 3139b096a687..682eff5ccfc6 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java
@@ -48,6 +48,7 @@
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.*;
import org.nd4j.linalg.activations.Activation;
@@ -1041,6 +1042,7 @@ public void testEpochCounter() throws Exception {
assertEquals(4, net.getLayerWiseConfigurations().getEpochCount());
+ ToSameDiffTests.testToSameDiff(net, null, null);
MultiLayerNetwork restored = TestUtils.testModelSerialization(net);
assertEquals(4, restored.getLayerWiseConfigurations().getEpochCount());
}
@@ -1242,6 +1244,7 @@ public void testZeroParamNet() throws Exception {
net.fit(ds);
+ ToSameDiffTests.testToSameDiff(net, null, null);
MultiLayerNetwork net2 = TestUtils.testModelSerialization(net);
INDArray out2 = net2.output(ds.getFeatures());
assertEquals(out, out2);
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java
index aa552a859dd9..07b333f775a9 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java
@@ -28,6 +28,7 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.samediff.ToSameDiffTests;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -88,6 +89,8 @@ public void testModelSerializerFrozenLayers() throws Exception {
assertEquals(out, out2);
+ ToSameDiffTests.testToSameDiff(withFrozen, in, null);
+
//Sanity check on train mode:
out = withFrozen.output(in, true);
out2 = restored.output(in, true);
@@ -141,5 +144,6 @@ public void testModelSerializerFrozenLayersCompGraph() throws Exception {
//Sanity check on train mode:
out = withFrozen.outputSingle(true, in);
out2 = restored.outputSingle(true, in);
+ ToSameDiffTests.testToSameDiff(withFrozen, in, null);
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/TestToSameDiff.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/TestToSameDiff.java
new file mode 100644
index 000000000000..3d7c5829ef5b
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/TestToSameDiff.java
@@ -0,0 +1,588 @@
+/*
+ * ******************************************************************************
+ * * Copyright (c) 2020 Konduit K.K.
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.samediff;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
+import org.deeplearning4j.nn.conf.Updater;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.deeplearning4j.util.ToSameDiffUtils;
+import org.junit.Test;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.DataSet;
+import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
+import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.learning.regularization.Regularization;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
+import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
+
+@Slf4j
+public class TestToSameDiff extends BaseDL4JTest {
+
+ @Override
+ public long getTimeoutMilliseconds() {
+ return super.getTimeoutMilliseconds() * 10;
+ }
+
+ private static final double eps = 1e-2;
+
+ private static final String expectedSummary = "--- Summary ---\n"
+ + "Variables: 30 (8 with arrays)\n"
+ + "Functions: 20 \n"
+ + "SameDiff Function Defs: 0 \n"
+ + "Loss function variables: [loss]\n"
+ + "\n"
+ + "--- Variables ---\n"
+ + "- Name - - Array Shape - - Variable Type - - Data Type- - Output Of Function - - Inputs To Functions -\n"
+ + "input [-1, 1, 28, 28] PLACEHOLDER FLOAT [layer0/inputPreprocessor/reshape]\n"
+ + "layer0/inputPreprocessor/reshape - ARRAY FLOAT layer0/inputPreprocessor/reshape(reshape) [layer0/conv2d] \n"
+ + "layer0/b [1, 20] VARIABLE FLOAT [layer0/conv2d] \n"
+ + "layer0/W [20, 1, 5, 5] VARIABLE FLOAT [layer0/conv2d] \n"
+ + "layer0/conv2d - ARRAY FLOAT layer0/conv2d(conv2d) [layer1/maxpool2d] \n"
+ + "layer1/maxpool2d - ARRAY FLOAT layer1/maxpool2d(maxpool2d) [layer2/conv2d] \n"
+ + "layer2/b [1, 50] VARIABLE FLOAT [layer2/conv2d] \n"
+ + "layer2/W [50, 20, 5, 5] VARIABLE FLOAT [layer2/conv2d] \n"
+ + "layer2/conv2d - ARRAY FLOAT layer2/conv2d(conv2d) [layer3/maxpool2d] \n"
+ + "layer3/maxpool2d - ARRAY FLOAT layer3/maxpool2d(maxpool2d) [layer4/inputPreprocessor/reshape]\n"
+ + "layer4/inputPreprocessor/reshape - ARRAY FLOAT layer4/inputPreprocessor/reshape(reshape) [layer4/mmul] \n"
+ + "layer4/b [1, 500] VARIABLE FLOAT [layer4/add] \n"
+ + "layer4/W [800, 500] VARIABLE FLOAT [layer4/mmul] \n"
+ + "layer4/mmul - ARRAY FLOAT layer4/mmul(mmul) [layer4/add] \n"
+ + "layer4/add - ARRAY FLOAT layer4/add(add) [layer4/relu] \n"
+ + "layer4/relu - ARRAY FLOAT layer4/relu(relu) [layer5/mmul] \n"
+ + "layer5/b [1, 10] VARIABLE FLOAT [layer5/add] \n"
+ + "layer5/W [500, 10] VARIABLE FLOAT [layer5/mmul] \n"
+ + "layer5/mmul - ARRAY FLOAT layer5/mmul(mmul) [layer5/add] \n"
+ + "layer5/add - ARRAY FLOAT layer5/add(add) [layer5/softmax] \n"
+ + "layer5/softmax - ARRAY FLOAT layer5/softmax(softmax) [layer5/loss/ClipByValue]\n"
+ + "labels [-1, 10] PLACEHOLDER FLOAT [layer5/loss/multiply, layer5/loss/size_at]\n"
+ + "layer5/loss/ClipByValue - ARRAY FLOAT layer5/loss/ClipByValue(ClipByValue) [layer5/loss/log] \n"
+ + "layer5/loss/log - ARRAY FLOAT layer5/loss/log(log) [layer5/loss/multiply]\n"
+ + "layer5/loss/multiply - ARRAY FLOAT layer5/loss/multiply(multiply) [layer5/loss/neg] \n"
+ + "layer5/loss/neg - ARRAY FLOAT layer5/loss/neg(neg) [layer5/loss/reduce_sum]\n"
+ + "layer5/loss/reduce_sum - ARRAY FLOAT layer5/loss/reduce_sum(reduce_sum) [layer5/loss/divide]\n"
+ + "layer5/loss/size_at - ARRAY LONG layer5/loss/size_at(size_at) [layer5/loss/cast] \n"
+ + "layer5/loss/cast - ARRAY FLOAT layer5/loss/cast(cast) [layer5/loss/divide]\n"
+ + "loss - ARRAY FLOAT layer5/loss/divide(divide) \n"
+ + "\n"
+ + "\n"
+ + "--- Functions ---\n"
+ + " - Function Name - - Op - - Inputs - - Outputs - \n"
+ + "0 layer0/inputPreprocessor/reshape Reshape [input] [layer0/inputPreprocessor/reshape] \n"
+ + "1 layer0/conv2d Conv2D [layer0/inputPreprocessor/reshape, layer0/W, layer0/b] [layer0/conv2d] \n"
+ + "2 layer1/maxpool2d MaxPooling2D [layer0/conv2d] [layer1/maxpool2d] \n"
+ + "3 layer2/conv2d Conv2D [layer1/maxpool2d, layer2/W, layer2/b] [layer2/conv2d] \n"
+ + "4 layer3/maxpool2d MaxPooling2D [layer2/conv2d] [layer3/maxpool2d] \n"
+ + "5 layer4/inputPreprocessor/reshape Reshape [layer3/maxpool2d] [layer4/inputPreprocessor/reshape] \n"
+ + "6 layer4/mmul Mmul [layer4/inputPreprocessor/reshape, layer4/W] [layer4/mmul] \n"
+ + "7 layer4/add AddOp [layer4/mmul, layer4/b] [layer4/add] \n"
+ + "8 layer4/relu RectifiedLinear [layer4/add] [layer4/relu] \n"
+ + "9 layer5/mmul Mmul [layer4/relu, layer5/W] [layer5/mmul] \n"
+ + "10 layer5/add AddOp [layer5/mmul, layer5/b] [layer5/add] \n"
+ + "11 layer5/softmax SoftMax [layer5/add] [layer5/softmax] \n"
+ + "12 layer5/loss/ClipByValue ClipByValue [layer5/softmax] [layer5/loss/ClipByValue] \n"
+ + "13 layer5/loss/log Log [layer5/loss/ClipByValue] [layer5/loss/log] \n"
+ + "14 layer5/loss/multiply MulOp [layer5/loss/log, labels] [layer5/loss/multiply] \n"
+ + "15 layer5/loss/neg Negative [layer5/loss/multiply] [layer5/loss/neg] \n"
+ + "16 layer5/loss/reduce_sum Sum [layer5/loss/neg] [layer5/loss/reduce_sum] \n"
+ + "17 layer5/loss/size_at SizeAt [labels] [layer5/loss/size_at] \n"
+ + "18 layer5/loss/cast Cast [layer5/loss/size_at] [layer5/loss/cast] \n"
+ + "19 layer5/loss/divide DivOp [layer5/loss/reduce_sum, layer5/loss/cast] [loss] \n";
+
+ @Override
+ public DataType getDataType() {
+ return DataType.DOUBLE;
+ }
+
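+ /**
+ * Checks that the DL4J network and its SameDiff conversion produce the same
+ * forward-pass output (within eps) for the given input.
+ */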
+ public static void testSameDiffInference(MultiLayerNetwork network, SameDiff sameDiff, INDArray input,
+ String name) {
+ INDArray dl4j = network.output(input.dup());
+ INDArray sd = sameDiff.batchOutput()
+ .input("input", input.dup())
+ .output(sameDiff.outputs().get(0))
+ .outputSingle();
+
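+ // Skip the comparison when both outputs contain NaNs (numerically unstable configurations)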
+ if (sd.isNaN().any() && dl4j.isNaN().any()) {
+ return;
+ }
+
+// assertEquals("Sums of DL4J and SameDiff outputs differ for " + name, dl4j.sumNumber().doubleValue(), sd.sumNumber().doubleValue(), eps);
+
+ assertTrue("Output of DL4J and SameDiff differ for " + name, dl4j.equalsWithEps(sd, eps));
+ }
+
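+ /**
+ * Verifies that each DL4J layer parameter equals the corresponding SameDiff
+ * variable, resolved via the per-layer name scopes from ToSameDiffUtils.
+ */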
+ public static void testWeights(MultiLayerNetwork network, SameDiff sameDiff, String name) {
+ List<String> names = ToSameDiffUtils.getScopeNames(network.getLayers());
+ for (int i = 0; i < network.getnLayers(); i++) {
+ String nameScope = names.get(i);
+ for (Map.Entry<String, INDArray> entry : network.getLayer(i).paramTable().entrySet()) {
+ String paramName = entry.getKey();
+ INDArray dl4j = entry.getValue();
+ INDArray sd = sameDiff.getArrForVarName(nameScope + "/" + paramName);
+
+ assertTrue("Weight " + nameScope + "/" + paramName + " differs for" + name, dl4j.equalsWithEps(sd, eps));
+ }
+ }
+ }
+
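+ /**
+ * Runs a single forward/backward pass in both DL4J and SameDiff, then compares
+ * the score and the parameter gradients of layer 0 (and layer 1, if present).
+ */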
+ public static void testBackprop(MultiLayerNetwork network, SameDiff sameDiff, INDArray input, INDArray labels) {
+ network.setInput(input);
+ network.setLabels(labels);
+
+ network.computeGradientAndScore();
+
+ int batchSize = (int) input.size(0);
+
+ double dl4jScore = network.score();
+ double sdScore = sameDiff.batchOutput()
+ .input("labels", labels)
+ .input("input", input)
+ .output(sameDiff.getLossVariables().get(0))
+ .outputSingle().sumNumber().doubleValue();
+ assertEquals("Losses differed", dl4jScore, sdScore, eps);
+
+ Map<String, INDArray> dl4jGradient = network.gradient().gradientForVariable();
+
+ boolean has2ndLayer = dl4jGradient.containsKey("1_W");
+
+ INDArray dl4jWeightGrad = dl4jGradient.get("0_W");
+ INDArray dl4jBiasGrad = dl4jGradient.get("0_b");
+
+ Map<String, INDArray> placeholderMap = new HashMap<>();
+ placeholderMap.put("labels", labels);
+ placeholderMap.put("input", input);
+
+ Set<String> gradientVars = new HashSet<>();
+
+ for (String k : sameDiff.variableMap().keySet()) {
+ if (sameDiff.getVariable(k).dataType().isFPType()) {
+ gradientVars.add(k);
+ }
+ }
+
+ Map<String, INDArray> sameDiffGradient = sameDiff.calculateGradients(placeholderMap, gradientVars);
+
+ // SameDiff divides by the batch size during the gradient calculation, whereas DL4J does it afterwards
+ for (Map.Entry<String, INDArray> entry : sameDiffGradient.entrySet()) {
+ entry.setValue(entry.getValue().mul(batchSize));
+ }
+
+ INDArray sdWeightGrad = sameDiffGradient.get("layer0/W");
+ INDArray sdBiasGrad = sameDiffGradient.get("layer0/b");
+
+ assertTrue("Weight 0 gradient differs", dl4jWeightGrad.equalsWithEps(sdWeightGrad, eps));
+ assertTrue("Bias 0 gradient differs", dl4jBiasGrad.equalsWithEps(sdBiasGrad, eps));
+
+ if (has2ndLayer) {
+
+ INDArray dl4jWeightGrad2 = dl4jGradient.get("1_W");
+ INDArray dl4jBiasGrad2 = dl4jGradient.get("1_b");
+
+ INDArray sdWeightGrad2 = sameDiffGradient.get("layer1/W");
+ INDArray sdBiasGrad2 = sameDiffGradient.get("layer1/b");
+
+ assertTrue("Weight 1 gradient differs", dl4jWeightGrad2.equalsWithEps(sdWeightGrad2, eps));
+ assertTrue("Bias 1 gradient differs", dl4jBiasGrad2.equalsWithEps(sdBiasGrad2, eps));
+ }
+ }
+
+ public static void testSameDiffInference(ComputationGraph network, SameDiff sameDiff, INDArray input, String name) {
+ INDArray dl4j = network.output(input)[0];
+ INDArray sd = sameDiff.batchOutput()
+ .input("in", input)
+ .output(sameDiff.outputs().get(0))
+ .outputSingle();
+
+ assertTrue("Output of DL4J and SameDiff differ for " + name, dl4j.equalsWithEps(sd, eps));
+ }
+
+ @Test
+ public void testMcXent() {
+ Nd4j.getRandom().setSeed(123);
+
+ ILossFunction loss = new LossMCXENT();
+ IActivation activation = new ActivationSoftmax();
+
+ INDArray input = Nd4j.rand(5, 4);
+ INDArray labels = Nd4j.rand(5, 4);
+
+ INDArray dl4grad = loss.computeGradient(labels.dup(), input.dup(), activation, null);
+
+ SameDiff sameDiff = SameDiff.create();
+ SDVariable inputVar = sameDiff.placeHolder("input", input.dataType(), input.shape());
+ SDVariable labelsVar = sameDiff.placeHolder("labels", labels.dataType(), labels.shape());
+
+ SDVariable out = sameDiff.nn.softmax(inputVar);
+ // not dividing by batch size, since DL4J applies that division later
+ SDVariable lossVar = sameDiff.math.log(out).mul(labelsVar).neg().sum();
+
+ sameDiff.setLossVariables(lossVar);
+
+ Map<String, INDArray> placeholderMap = new HashMap<>();
+ placeholderMap.put("input", input.dup());
+ placeholderMap.put("labels", labels.dup());
+
+ sameDiff.createGradFunction("input");
+
+ INDArray sdGrad = sameDiff.calculateGradients(placeholderMap, lossVar.name(), "input").get("input");
+
+ assertTrue(dl4grad.equalsWithEps(sdGrad, eps));
+ }
+
+ @Test
+ public void testSimple() throws IOException {
+ int seed = 123;
+
+ Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
+
+ boolean[] useDenses = {false};
+ Updater[] updaters = {Updater.SGD, Updater.ADAM, Updater.ADAMAX, Updater.ADADELTA, Updater.NESTEROVS,
+ Updater.NADAM/*, Updater.ADAGRAD, Updater.RMSPROP*/, Updater.NONE};
+ Regularization[] regularizations = {null}; // {new L2Regularization(0.0005), new L1Regularization(0.005),
+// new WeightDecay(0.03, true)};
+ LossFunction[] lossFunctions = {LossFunction.MSE, LossFunction.L1/*, LossFunction.MCXENT*/,
+ /*LossFunction.COSINE_PROXIMITY,*/ LossFunction.HINGE,
+ LossFunction.SQUARED_HINGE/*, LossFunction.KL_DIVERGENCE*/, LossFunction.MEAN_ABSOLUTE_ERROR,
+ LossFunction.L2/*, LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR*//*,
+ LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR*//*, LossFunction.POISSON*/, LossFunction.WASSERSTEIN};
+
+ Activation[] activations = {/*Activation.CUBE, */Activation.ELU, Activation.HARDSIGMOID, Activation.HARDTANH,
+ Activation.IDENTITY, Activation.LEAKYRELU, Activation.RATIONALTANH, Activation.RELU, Activation.RELU6,
+ Activation.RRELU, Activation.SIGMOID/*, Activation.SOFTMAX*/, Activation.SOFTPLUS, Activation.SOFTSIGN,
+ Activation.TANH, Activation.RECTIFIEDTANH, Activation.SELU, Activation.SWISH,
+ Activation.THRESHOLDEDRELU, Activation.GELU, Activation.MISH};
+
+ List<String> failures = new ArrayList<>();
+
+ for (Updater updater : updaters) {
+ for (LossFunction lossFunction : lossFunctions) {
+ for (Activation activation : activations) {
+ for (boolean useDense : useDenses) {
+ for (Regularization regularization : regularizations) {
+
+ if (updater == Updater.CUSTOM) {
+ continue;
+ }
+
+ IUpdater iUpdater = updater.getIUpdaterWithDefaultConfig();
+
+ log.info("Test with {}, {}, {}, {}, and {}", useDense ? "dense layer" : "no dense layer",
+ regularization, activation, lossFunction, iUpdater);
+
+ try {
+ Nd4j.getRandom().setSeed(seed);
+
+ ListBuilder partial = new NeuralNetConfiguration.Builder()
+ .seed(seed)
+ .dataType(DataType.DOUBLE)
+ .updater(iUpdater)
+ .regularization(
+ regularization != null ? Collections.singletonList(regularization)
+ : Collections.emptyList())
+ .regularizationBias(
+ regularization != null ? Collections.singletonList(regularization)
+ : Collections.emptyList())
+ .list();
+
+ if (useDense) {
+ partial.layer(new DenseLayer.Builder()
+ .activation(Activation.RELU)
+ .nOut(4).build());
+ }
+
+ MultiLayerConfiguration config = partial
+ .layer(new OutputLayer.Builder(lossFunction)
+ .activation(activation).nIn(4).nOut(3).build())
+ .setInputType(InputType.feedForward(4))
+ .validateOutputLayerConfig(false)
+ .build();
+
+ MultiLayerNetwork network = new MultiLayerNetwork(config);
+ network.init();
+
+ Nd4j.getRandom().setSeed(seed);
+
+ INDArray example = Nd4j.rand(5, 4).mul(2);
+ DataSet ds = new DataSet(Nd4j.rand(5, 4).mul(2), Nd4j.rand(5, 3).mul(2));
+ DataSetIterator iter = new SingletonDataSetIterator(ds);
+
+ // --- training tests ---
+
+ // train DL4J first
+ network.fit(iter, 1);
+ assertEquals(1, network.getIterationCount());
+ assertEquals(1, network.getEpochCount());
+ iter.reset();
+
+ // copy (w/ params and updater state)
+
+ SameDiff mnistSameDiff;
+ try {
+ mnistSameDiff = network.toSameDiff(null, false, false);
+ } catch (UnsupportedOperationException e) {
+ continue;
+ }
+ testSameDiffInference(network, mnistSameDiff, example, "Post DL4J Training");
+
+// testBackprop(network, mnistSameDiff, ds.getFeatures().dup(), ds.getLabels().dup());
+
+ // train 2 more epochs
+// iter.reset();
+// mnistSameDiff.fit(iter, 1);
+// assertEquals(1, mnistSameDiff.getTrainingConfig().getIterationCount());
+// assertEquals(1, mnistSameDiff.getTrainingConfig().getEpochCount());
+//
+// iter.reset();
+// network.fit(iter, 1);
+// assertEquals(1, network.getIterationCount());
+// assertEquals(1, network.getEpochCount());
+//
+// testSameDiffInference(network, mnistSameDiff, example, "Post 1st Training");
+
+ testWeights(network, mnistSameDiff, "Copy");
+
+ iter.reset();
+ mnistSameDiff.fit(iter, 1);
+ assertEquals(2, mnistSameDiff.getTrainingConfig().getIterationCount());
+ assertEquals(2, mnistSameDiff.getTrainingConfig().getEpochCount());
+
+ iter.reset();
+ network.fit(iter, 1);
+ assertEquals(2, network.getIterationCount());
+ assertEquals(2, network.getEpochCount());
+
+ testWeights(network, mnistSameDiff, "Post Train");
+ testSameDiffInference(network, mnistSameDiff, example, "Post 2nd Training");
+ } catch (AssertionError ae) {
+ ae.printStackTrace();
+ failures.add((useDense ? "Dense Layer" : "No Dense Layer") + " with " + regularization
+ + ", " + activation
+ + ", " + lossFunction
+ + ", and " + iUpdater);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ log.info(" --- Failures --- ");
+ for (String f : failures) {
+ log.info(f);
+ }
+
+ assertTrue("There were failed tests", failures.isEmpty());
+
+ }
+
+ @Test
+ public void testConversionAndTraining() throws IOException {
+ int seed = 123;
+ int outputNum = 10;
+
+ Nd4j.getRandom().setSeed(seed);
+
+ MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
+ .seed(seed)
+ .l2(0.0005)
+ .l2Bias(0.0005)
+ .weightInit(WeightInit.XAVIER)
+ .updater(new Adam())
+ .list()
+ .layer(new ConvolutionLayer.Builder(5, 5)
+ .stride(1, 1)
+ .nOut(20)
+ .activation(Activation.IDENTITY)
+ .build())
+ .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
+ .kernelSize(2, 2)
+ .stride(2, 2)
+ .build())
+ .layer(new ConvolutionLayer.Builder(5, 5)
+ .stride(1, 1)
+ .nOut(50)
+ .activation(Activation.IDENTITY)
+ .build())
+ .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
+ .kernelSize(2, 2)
+ .stride(2, 2)
+ .build())
+ .layer(new DenseLayer.Builder().activation(Activation.RELU)
+ .nOut(500).build())
+ .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
+ .nOut(outputNum)
+ .activation(Activation.SOFTMAX)
+ .build())
+ .setInputType(InputType.convolutionalFlat(28, 28, 1))
+ .build();
+
+ MultiLayerNetwork network = new MultiLayerNetwork(config);
+ network.init();
+
+ Nd4j.getRandom().setSeed(seed);
+ SameDiff mnistSameDiff = network.toSameDiff(null, false, false);
+
+ assertEquals("More than one output", 1, mnistSameDiff.outputs().size());
+ assertEquals("More than one loss", 1, mnistSameDiff.getLossVariables().size());
+ assertNotNull(mnistSameDiff.getTrainingConfig());
+
+ assertEquals("Summaries aren't equal", expectedSummary, mnistSameDiff.summary());
+
+ MnistDataSetIterator trainData = new MnistDataSetIterator(2, 2);
+
+ INDArray example = trainData.next().getFeatures().dup();
+
+ testSameDiffInference(network, mnistSameDiff, example, "Inference");
+
+ // --- training tests ---
+
+ // train DL4J first
+ network.fit(trainData, 1);
+ trainData.reset();
+
+ // copy (w/ params and updater state)
+
+ mnistSameDiff = network.toSameDiff(null, false, false);
+ testSameDiffInference(network, mnistSameDiff, example, "Post DL4J Training");
+
+ // train 2 more epochs
+ trainData.reset();
+ mnistSameDiff.fit(trainData, 1);
+
+ trainData.reset();
+ network.fit(trainData, 1);
+
+ testSameDiffInference(network, mnistSameDiff, example, "Post 2nd Training");
+ }
+
+ @Test
+ public void testConversionAndTrainingGraph() throws IOException {
+ int seed = 123;
+ int outputNum = 10;
+
+ MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
+ .seed(seed)
+// .l2(0.0005)
+// .l2Bias(0.0005)
+// .weightInit(WeightInit.XAVIER)
+ .updater(new Adam(eps))
+ .list()
+ .layer(new ConvolutionLayer.Builder(5, 5)
+ .stride(1, 1)
+ .nOut(20)
+ .activation(Activation.IDENTITY)
+ .build())
+ .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
+ .kernelSize(2, 2)
+ .stride(2, 2)
+ .build())
+ .layer(new ConvolutionLayer.Builder(5, 5)
+ .stride(1, 1)
+ .nOut(50)
+ .activation(Activation.IDENTITY)
+ .build())
+ .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
+ .kernelSize(2, 2)
+ .stride(2, 2)
+ .build())
+ .layer(new DenseLayer.Builder().activation(Activation.RELU)
+ .nOut(500).build())
+ .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
+ .nOut(outputNum)
+ .activation(Activation.SOFTMAX)
+ .build())
+ .setInputType(InputType.convolutionalFlat(28, 28, 1))
+ .build();
+
+ MultiLayerNetwork net = new MultiLayerNetwork(config);
+ net.init();
+
+ ComputationGraph graph = net.toComputationGraph();
+ graph.init();
+
+ Map<String, InputType> inputTypes = new HashMap<>();
+ inputTypes.put("in", InputType.convolutionalFlat(28, 28, 1));
+ SameDiff mnistSameDiff = graph.toSameDiff(inputTypes, true, true);
+
+ assertEquals("More than one output", 1, mnistSameDiff.outputs().size());
+ assertEquals("More than one loss", 1, mnistSameDiff.getLossVariables().size());
+ assertNotNull(mnistSameDiff.getTrainingConfig());
+
+ MnistDataSetIterator trainData = new MnistDataSetIterator(10, 10);
+
+ INDArray example = trainData.next().getFeatures().dup();
+
+ testSameDiffInference(graph, mnistSameDiff, example, "Inference");
+
+ // --- training tests ---
+
+ // train DL4J first
+ graph.fit(trainData, 2);
+ trainData.reset();
+
+ // copy (w/ params and updater state)
+
+ mnistSameDiff = graph.toSameDiff(inputTypes, true, false);
+ testSameDiffInference(graph, mnistSameDiff, example, "Post DL4J Training");
+
+ // train 2 more epochs
+ trainData.reset();
+ mnistSameDiff.fit(trainData, 2);
+
+ trainData.reset();
+ graph.fit(trainData, 2);
+
+ testSameDiffInference(graph, mnistSameDiff, example, "Post 2nd Training");
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/ToSameDiffTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/ToSameDiffTests.java
new file mode 100644
index 000000000000..6b1aa9f19951
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/ToSameDiffTests.java
@@ -0,0 +1,763 @@
+/*
+ * ******************************************************************************
+ * * Copyright (c) 2020 Konduit K.K.
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.samediff;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.common.reflect.ClassPath;
+import com.google.common.reflect.ClassPath.ClassInfo;
+import java.io.IOException;
+import java.lang.reflect.Modifier;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.nn.api.layers.IOutputLayer;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
+import org.deeplearning4j.nn.conf.dropout.IDropout;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Cnn3DLossLayer;
+import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
+import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
+import org.deeplearning4j.nn.conf.layers.Convolution2D;
+import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
+import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
+import org.deeplearning4j.nn.conf.layers.Layer;
+import org.deeplearning4j.nn.conf.layers.LayerWithLoss;
+import org.deeplearning4j.nn.conf.layers.Pooling1D;
+import org.deeplearning4j.nn.conf.layers.Pooling2D;
+import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
+import org.deeplearning4j.nn.conf.layers.Subsampling1DLayer;
+import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
+import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
+import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
+import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.graph.vertex.GraphVertex;
+import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
+import org.deeplearning4j.nn.layers.BaseOutputLayer;
+import org.deeplearning4j.nn.layers.LossLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.ToSameDiffUtils;
+import org.junit.runner.Result;
+import org.junit.runner.notification.RunListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.common.base.Preconditions;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+
+
+@Slf4j
+public class ToSameDiffTests extends RunListener {
+
+ public static boolean SKIP_UNIMPLEMENTED = true;
+ public static boolean FAIL_FAST = true;
+ public static boolean FAIL_IF_MISSING = false;
+ // printing after every test makes the report visible in IDEA test runs
+ public static boolean PRINT_AFTER_EVERY = true;
+
+ private static final Set<String> failurePointLayers = new HashSet<>();
+ private static final Set<String> failurePointVertices = new HashSet<>();
+ private static final Set<String> failureLosses = new HashSet<>();
+
+ private static void cleanupLayers(Set<Class<? extends Layer>> layers){
+ if(layers.remove(Convolution1D.class))
+ layers.add(Convolution1DLayer.class);
+
+ if(layers.remove(Convolution2D.class))
+ layers.add(ConvolutionLayer.class);
+
+ if(layers.remove(Pooling1D.class))
+ layers.add(Subsampling1DLayer.class);
+
+ if(layers.remove(Pooling2D.class))
+ layers.add(SubsamplingLayer.class);
+ }
+
+ private static void cleanupLosses(Set<Class<? extends ILossFunction>> layers){
+
+ }
+
+ private static void cleanupDropouts(Set<Class<? extends IDropout>> layers){
+
+ }
+
+ private static void cleanupActivations(Set<Class<? extends IActivation>> layers){
+
+ }
+
+ private static void cleanupPreprocessors(Set<Class<? extends InputPreProcessor>> layers){
+
+ }
+
+ private static void cleanupVertices(Set<Class<? extends GraphVertex>> layers){
+
+ }
+
+ private static <T> Set<Class<? extends T>> findClasses(Class<T> superClass, String topPackage) {
+
+ Set<ClassInfo> infos;
+ try {
+ infos = ClassPath.from(superClass.getClassLoader()).getTopLevelClassesRecursive(topPackage);
+ } catch (IOException e) {
+ infos = new HashSet<>();
+ }
+
+ Set<Class<? extends T>> classes = new HashSet<>();
+ for(ClassInfo ci : infos){
+ Class<?> c = ci.load();
+ if(superClass.isAssignableFrom(c) &&
+ !Modifier.isAbstract(c.getModifiers()) &&
+ !c.isInterface() &&
+ !c.getSimpleName().toLowerCase().contains("custom"))
+ classes.add(c.asSubclass(superClass));
+ }
+ return classes;
+ }
+
+ private static Set<Class<? extends Layer>> findLayers() {
+ Set<Class<? extends Layer>> ret = findClasses(Layer.class, "org.deeplearning4j.nn.conf.layers");
+ cleanupLayers(ret);
+ return ret;
+ }
+
+ private static Set<Class<? extends ILossFunction>> findLosses() {
+ Set<Class<? extends ILossFunction>> ret = findClasses(ILossFunction.class, "org.nd4j.linalg.lossfunctions");
+ cleanupLosses(ret);
+ return ret;
+ }
+
+ private static Set<Class<? extends IDropout>> findDropouts() {
+ Set<Class<? extends IDropout>> ret = findClasses(IDropout.class, "org.deeplearning4j.nn.conf.dropout");
+ cleanupDropouts(ret);
+ return ret;
+ }
+
+ private static Set<Class<? extends IActivation>> findActivations() {
+ Set<Class<? extends IActivation>> ret = findClasses(IActivation.class, "org.nd4j.linalg.activations");
+ cleanupActivations(ret);
+ return ret;
+ }
+
+ private static Set<Class<? extends InputPreProcessor>> findPreprocessors() {
+ Set<Class<? extends InputPreProcessor>> ret = findClasses(InputPreProcessor.class, "org.deeplearning4j.nn.conf.preprocessor");
+ cleanupPreprocessors(ret);
+ return ret;
+ }
+
+ private static Set<Class<? extends GraphVertex>> findVertices() {
+ Set<Class<? extends GraphVertex>> ret = findClasses(GraphVertex.class, "org.deeplearning4j.nn.graph.vertex.impl");
+ cleanupVertices(ret);
+ return ret;
+ }
+
+ private enum Stage{
+ Conversion, Output, Loss;
+
+ public Set<Class<? extends Layer>> testedLayers = new HashSet<>();
+ public Set<Class<? extends ILossFunction>> testedLosses = new HashSet<>();
+ public Set<Class<? extends IDropout>> testedDropouts = new HashSet<>();
+ public Set<Class<? extends IActivation>> testedActivations = new HashSet<>();
+ public Set<Class<? extends InputPreProcessor>> testedPreprocessors = new HashSet<>();
+ public Set<Class<? extends GraphVertex>> testedVertices = new HashSet<>();
+
+ public void cleanup(){
+ cleanupLayers(testedLayers);
+ cleanupLosses(testedLosses);
+ cleanupDropouts(testedDropouts);
+ cleanupActivations(testedActivations);
+ cleanupPreprocessors(testedPreprocessors);
+ cleanupVertices(testedVertices);
+ }
+
+ private static <T> Set<String> minusStr(Set<Class<? extends T>> a, Set<Class<? extends T>> b){
+ Set<String> ret = new HashSet<>();
+ for(Class<? extends T> c : a){
+ if(!b.contains(c))
+ ret.add(c.getSimpleName());
+ }
+ return ret;
+ }
+
+ public int check(Set<Class<? extends Layer>> foundLayers,
+ Set<Class<? extends ILossFunction>> foundLosses,
+ Set<Class<? extends IDropout>> foundDropouts,
+ Set<Class<? extends IActivation>> foundActivations,
+ Set<Class<? extends InputPreProcessor>> foundPreprocessors,
+ Set<Class<? extends GraphVertex>> foundVertices
+ ){
+
+ if(this == Stage.Loss){
+ // only care about losses & output/loss layers here
+ foundDropouts.clear();
+ foundActivations.clear();
+ foundPreprocessors.clear();
+ foundVertices.clear();
+
+ Set<Class<? extends Layer>> old = foundLayers;
+ foundLayers = new HashSet<>();
+ for(Class<? extends Layer> layer : old){
+ if(LayerWithLoss.class.isAssignableFrom(layer))
+ foundLayers.add(layer);
+ }
+ }
+
+ Set<String> missingLayers = minusStr(foundLayers, testedLayers);
+ Set<String> missingLosses = minusStr(foundLosses, testedLosses);
+ Set<String> missingDropouts = minusStr(foundDropouts, testedDropouts);
+ Set<String> missingActivations = minusStr(foundActivations, testedActivations);
+ Set<String> missingPreprocessors = minusStr(foundPreprocessors, testedPreprocessors);
+ Set<String> missingVertices = minusStr(foundVertices, testedVertices);
+
+ if(this != Stage.Loss)
+ log.info(" --- ToSameDiff {} Tests --- ", name());
+ else
+ log.info(" --- ToSameDiff Loss Tests (only layers that define losses and loss functions are shown) --- ");
+
+ log.info("Missing Layers: {}", missingLayers);
+
+ if(this != Stage.Loss) {
+ log.info("Missing Activations: {}", missingActivations);
+ }
+
+ log.info("Missing Losses: {}", missingLosses);
+
+ if(this != Stage.Loss) {
+ log.info("Missing Preprocessors: {}", missingPreprocessors);
+ log.info("Missing Dropouts: {}", missingDropouts);
+ log.info("Missing Vertices: {}", missingVertices);
+ }
+
+ return missingLayers.size() + missingLosses.size() + missingDropouts.size() +
+ missingActivations.size() + missingPreprocessors.size() + missingVertices.size();
+ }
+
+ public void record(InputPreProcessor preProcessor){
+ if(preProcessor != null) {
+ testedPreprocessors.add(preProcessor.getClass());
+ }
+ }
+
+ public void record(IActivation activation){
+ if(activation != null){
+ testedActivations.add(activation.getClass());
+ }
+ }
+
+ public void record(ILossFunction lossFunction){
+ if(lossFunction != null){
+ testedLosses.add(lossFunction.getClass());
+ }
+ }
+
+ public void record(IDropout dropout){
+ if(dropout != null){
+ testedDropouts.add(dropout.getClass());
+ }
+ }
+
+ public void record(Layer layer){
+ if(layer == null)
+ return;
+
+ testedLayers.add(layer.getClass());
+ record(layer.getIDropout());
+
+ if(layer instanceof BaseWrapperLayer){
+ record(((BaseWrapperLayer) layer).getUnderlying());
+ } else if(layer instanceof FrozenLayer) {
+ record(((FrozenLayer) layer).getLayer());
+ } else if(layer instanceof Bidirectional){
+ record(((Bidirectional) layer).getFwd());
+ record(((Bidirectional) layer).getBwd());
+ }
+
+ if(layer instanceof FeedForwardLayer){
+ record(((FeedForwardLayer) layer).getActivationFn());
+ }
+
+ if(layer instanceof org.deeplearning4j.nn.conf.layers.BaseOutputLayer){
+ record(((org.deeplearning4j.nn.conf.layers.BaseOutputLayer) layer).getLossFn());
+ } else if(layer instanceof CnnLossLayer){
+ record(((CnnLossLayer) layer).getLossFn());
+ } else if(layer instanceof RnnLossLayer){
+ record(((RnnLossLayer) layer).getLossFn());
+ } else if(layer instanceof org.deeplearning4j.nn.conf.layers.LossLayer){
+ record(((org.deeplearning4j.nn.conf.layers.LossLayer) layer).getLossFn());
+ } else if(layer instanceof Cnn3DLossLayer){
+ record(((Cnn3DLossLayer) layer).getLossFn());
+ }
+ }
+
+ public void record(GraphVertex vertex){
+ if(vertex == null)
+ return;
+
+ testedVertices.add(vertex.getClass());
+ if(vertex.hasLayer()){
+ record(vertex.getLayer().conf().getLayer());
+
+ }
+
+ if(vertex instanceof LayerVertex)
+ record(((LayerVertex) vertex).getLayerPreProcessor());
+
+ }
+ }
+
+ public static void testToSameDiff(@NonNull MultiLayerNetwork network, INDArray input, INDArray labels){
+ testToSameDiff(network, null, input, labels);
+ }
+
+ private static ILossFunction getLossFn(org.deeplearning4j.nn.api.Layer layer){
+ ILossFunction lossFn = null;
+ if(layer instanceof BaseOutputLayer){
+ lossFn = ((BaseOutputLayer<?>) layer).getLossFn();
+ } else if(layer instanceof LossLayer){
+ lossFn = ((LossLayer) layer).getLossFn();
+ } else if(layer instanceof org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer){
+ lossFn = ((org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer) layer).getLossFn();
+ } else if(layer instanceof org.deeplearning4j.nn.layers.convolution.CnnLossLayer){
+ lossFn = ((org.deeplearning4j.nn.layers.convolution.CnnLossLayer) layer).getLossFn();
+ } else if(layer instanceof org.deeplearning4j.nn.layers.recurrent.RnnLossLayer){
+ lossFn = ((org.deeplearning4j.nn.layers.recurrent.RnnLossLayer) layer).getLossFn();
+ }
+ return lossFn;
+ }
+
+ public static void testToSameDiff(@NonNull MultiLayerNetwork network, InputType inputType, INDArray input, INDArray labels){
+
+ for(int i = 0 ; i < network.getnLayers() ; i++){
+ Layer layer = network.getLayer(i).conf().getLayer();
+ Stage.Conversion.record(layer);
+ Stage.Conversion.record(network.getLayerWiseConfigurations().getInputPreProcess(i));
+ }
+
+ SameDiff sameDiff;
+ try{
+ sameDiff = network.toSameDiff(inputType, true, true);
+ } catch (UnsupportedOperationException e){
+ if(!SKIP_UNIMPLEMENTED)
+ throw e;
+ else
+ return;
+ } catch (IllegalStateException e){
+ if((e.getMessage().contains(" convert to SameDiff with different regularizations") ||
+ e.getMessage().contains(" convert to SameDiff with different IUpdaters")) && SKIP_UNIMPLEMENTED)
+ return;
+ else
+ throw e;
+ }
+
+ if(input == null){
+ long[] inputShape = sameDiff.getVariable("input").placeholderShape();
+ for(int i = 0 ; i < inputShape.length ; i++){
+ if(inputShape[i] == -1)
+ inputShape[i] = 1;
+ }
+
+ input = Nd4j.rand(inputShape);
+ }
+
+ for(int i = 0 ; i < network.getnLayers() ; i++){
+ Layer layer = network.getLayer(i).conf().getLayer();
+ Stage.Output.record(layer);
+ Stage.Output.record(network.getLayerWiseConfigurations().getInputPreProcess(i));
+ }
+
+ List<INDArray> activations = network.feedForward(input);
+ activations.remove(0);
+
+ List<String> sdActivationVariables = new ArrayList<>();
+
+ List<String> namesByLayer = ToSameDiffUtils.getScopeNames(network.getLayers());
+
+ List<String> layerClassNames = new ArrayList<>();
+ for(int i = 0 ; i < network.getnLayers() ; i++){
+ org.deeplearning4j.nn.conf.layers.Layer config = network.getLayerWiseConfigurations().getConf(i).getLayer();
+
+ String scope = namesByLayer.get(i);
+ List<SDVariable> scopeVars = sameDiff.getVariablesInScope(scope);
+ layerClassNames.add(config.getClass().getSimpleName());
+ if(scopeVars.size() > 0) {
+
+ SDVariable lastVar = null;
+ for(int j = scopeVars.size() - 1 ; j >= 0 ; j--){
+ SDVariable variable = scopeVars.get(j);
+
+ if(!variable.name().contains("/loss/") && !variable.name().endsWith("loss") && !variable.name().endsWith("labels")){
+ lastVar = variable;
+ break;
+ }
+
+ }
+
+ if(lastVar != null)
+ sdActivationVariables.add(lastVar.name());
+ else
+ sdActivationVariables.add(sdActivationVariables.get(sdActivationVariables.size() - 1));
+ } else
+ sdActivationVariables.add(sdActivationVariables.get(sdActivationVariables.size() - 1));
+ }
+
+ Map<String, INDArray> sdActivations = sameDiff.batchOutput()
+ .output(sdActivationVariables.toArray(new String[0]))
+ .input("input", input)
+ .output();
+
+
+ assertEquals("Sizes of DL4J activations and found SameDiff activations differ", activations.size(), sdActivationVariables.size());
+
+ List<Pair<String, String>> messages = new ArrayList<>();
+ boolean failed = false;
+ for(int i = 0 ; i < sdActivationVariables.size() ; i++){
+ INDArray sd = sdActivations.get(sdActivationVariables.get(i));
+ INDArray dl4j = activations.get(i);
+
+ if(!sd.equalsWithEps(dl4j, 1e-3)) {
+
+ if(!failed)
+ failurePointLayers.add(network.getLayer(i).conf().getLayer().getClass().getSimpleName());
+
+ failed = true;
+ if(FAIL_FAST)
+ fail("DL4J activation and SameDiff activation not equal for Layer " + layerClassNames.get(i) + " and SDVariable " + sdActivationVariables.get(i));
+ else
+ messages.add(new Pair<>(layerClassNames.get(i), sdActivationVariables.get(i)));
+ }
+ }
+
+ StringBuilder message = new StringBuilder("DL4J activation and SameDiff activation not equal for ");
+
+ for(Pair<String, String> pair : messages)
+ message.append("Layer ").append(pair.getFirst()).append(" and SDVariable ").append(pair.getSecond())
+ .append(", ");
+
+ assertEquals(message.toString(), 0, messages.size());
+
+ if(labels != null){
+
+ for(int i = 0 ; i < network.getnLayers() ; i++){
+ Layer layer = network.getLayer(i).conf().getLayer();
+ Stage.Loss.record(layer);
+ Stage.Loss.record(network.getLayerWiseConfigurations().getInputPreProcess(i));
+ }
+
+ INDArray output = network.output(input).dup();
+ network.setLabels(labels);
+ network.computeGradientAndScore();
+ double score = network.score() - network.calcRegularizationScore(true);
+
+ Map<String, INDArray> sdOutputs = sameDiff.batchOutput()
+ .output(sameDiff.outputs().get(0), sameDiff.getLossVariables().get(0))
+ .input("input", input)
+ .input("labels", labels)
+ .output();
+
+ INDArray sdLoss = sdOutputs.get(sameDiff.getLossVariables().get(0));
+ INDArray sdOutput = sdOutputs.get(sameDiff.outputs().get(0));
+
+
+ assertTrue("Outputs don't match for original network and SameDiff version", sdOutput.equalsWithEps(output, 1e-3));
+
+ double sdScore = sdLoss.sumNumber().doubleValue();
+
+ ILossFunction lossFn = getLossFn(network.getOutputLayer());
+ try {
+ assertEquals("Losses don't match for original network and SameDiff version" + (lossFn != null ?
+ " for loss function " + lossFn.getClass().getSimpleName() : ""),
+ sdScore, score, 1e-3);
+ } catch (AssertionError ae){
+ if(ae.getMessage().contains("Losses don't match") && lossFn != null){
+ failureLosses.add(lossFn.getClass().getSimpleName());
+ }
+ throw ae;
+ }
+ }
+
+ if(PRINT_AFTER_EVERY) {
+ printResults();
+ }
+
+ }
+
+ public static void testToSameDiff(@NonNull ComputationGraph graph, @NonNull INDArray inputs, INDArray labels){
+ INDArray[] labelsArray = null;
+ if(labels != null)
+ labelsArray = new INDArray[]{labels};
+
+ testToSameDiff(graph, new INDArray[]{inputs}, labelsArray);
+ }
+
+ public static void testToSameDiff(@NonNull ComputationGraph graph, @NonNull INDArray[] inputs, INDArray[] labels){
+ testToSameDiff(graph, inputs, labels, null);
+ }
+
+ public static void testToSameDiff(@NonNull ComputationGraph graph, @NonNull INDArray[] inputs, INDArray[] labels, InputType[] inputTypes){
+ Preconditions.checkArgument(inputs.length == graph.getConfiguration().getNetworkInputs().size(),
+ "Didn't supply the right number of inputs: expected %s, got %s", graph.getConfiguration().getNetworkInputs().size(), inputs.length);
+
+ Map<String, InputType> inputTypesMap = new HashMap<>();
+ Map<String, INDArray> inputsMap = new HashMap<>();
+
+ for(int i = 0 ; i < inputs.length ; i++){
+ String name = graph.getConfiguration().getNetworkInputs().get(i);
+ inputsMap.put(name, inputs[i]);
+
+ if(inputTypes != null && inputTypes.length > i && inputTypes[i] != null)
+ inputTypesMap.put(name, inputTypes[i]);
+ else
+ inputTypesMap.put(name, InputType.inferInputType(inputs[i]));
+ }
+
+ InputType[] inputVertTypes = new InputType[inputTypesMap.size()];
+ int j = 0;
+ for(String inputName : graph.getConfiguration().getNetworkInputs()){
+ inputVertTypes[j] = inputTypesMap.get(inputName);
+ j++;
+ }
+
+ try {
+ graph.getConfiguration().getLayerActivationTypes(true, inputVertTypes);
+ } catch (Exception e){
+ log.warn("Error getting activation types and adding preprocessors for graph", e);
+ }
+
+ for(GraphVertex vertex : graph.getVertices()){
+ Stage.Conversion.record(vertex);
+ }
+
+ SameDiff sameDiff;
+ try{
+ sameDiff = graph.toSameDiff(inputTypesMap, true, true);
+ } catch (UnsupportedOperationException e){
+ if(!SKIP_UNIMPLEMENTED)
+ throw e;
+ else
+ return;
+ } catch (IllegalStateException e){
+ if((e.getMessage().contains(" convert to SameDiff with different regularizations") ||
+ e.getMessage().contains(" convert to SameDiff with different IUpdaters") ||
+ e.getMessage().equals("Dimension must be set for toSameDiff conversion.")) &&
+ SKIP_UNIMPLEMENTED)
+ return;
+ else
+ throw e;
+ }
+
+ for(GraphVertex vertex : graph.getVertices()){
+ Stage.Output.record(vertex);
+ }
+
+ Map<String, INDArray> activations = graph.feedForward(inputs, false);
+
+ for(String inputName : inputsMap.keySet())
+ activations.remove(inputName);
+
+ List<String> activationKeys = new ArrayList<>();
+ for(String n : graph.getConfiguration().getTopologicalOrderStr()){
+ if(activations.containsKey(n))
+ activationKeys.add(n);
+ }
+
+ Map<String, SDVariable> sdActivationVariables = new HashMap<>();
+ for(String vertexName : activationKeys){
+ List<SDVariable> scopeVars = sameDiff.getVariablesInScope(vertexName);
+ if(!scopeVars.isEmpty()){
+ SDVariable lastVar = null;
+ for(int i = scopeVars.size() - 1 ; i >= 0 ; i--){
+ SDVariable variable = scopeVars.get(i);
+
+ if(!variable.name().contains("/loss/") && !variable.name().endsWith("loss") && !variable.name().endsWith("labels")){
+ lastVar = variable;
+ break;
+ }
+
+ }
+
+ if(lastVar != null)
+ sdActivationVariables.put(vertexName, lastVar);
+ else {
+ List<String> vertexInputs = graph.getConfiguration().getVertexInputs().get(vertexName);
+ if(vertexInputs.size() == 1){
+ sdActivationVariables.put(vertexName, sdActivationVariables.get(vertexInputs.get(0)));
+ }
+ }
+ }
+ }
+
+ Map<String, INDArray> sdActivations = sameDiff.batchOutput()
+ .inputs(inputsMap)
+ .output(sdActivationVariables.values().toArray(new SDVariable[0]))
+ .output();
+
+ assertEquals("Sizes of DL4J activations and found SameDiff activations differ", activations.size(), sdActivationVariables.size());
+
+
+ List<Pair<String, String>> messages = new ArrayList<>();
+ boolean failed = false;
+ for(String vertexName : activations.keySet()){
+ INDArray dl4j = activations.get(vertexName);
+ INDArray sd = sdActivations.get(sdActivationVariables.get(vertexName).name());
+
+ if(!sd.equalsWithEps(dl4j, 1e-3)) {
+ GraphVertex vertex = graph.getVertex(vertexName);
+
+ if(!failed){
+ if(vertex instanceof LayerVertex)
+ failurePointLayers.add(vertex.getLayer().conf().getLayer().getClass().getSimpleName());
+ else
+ failurePointVertices.add(vertex.getClass().getSimpleName());
+ }
+
+ failed = true;
+
+ String vertexStr = vertexName + "[" + vertex.getClass().getSimpleName();
+
+ if(vertex.hasLayer())
+ vertexStr += "(" + vertex.getLayer().conf().getLayer().getClass().getSimpleName() + ")";
+
+ vertexStr += "]";
+
+ if(FAIL_FAST)
+ fail("DL4J activation and SameDiff activation not equal for Vertex " + vertexStr + " and SDVariable " + sdActivationVariables.get(vertexName).name());
+ else
+ messages.add(new Pair<>(vertexStr, sdActivationVariables.get(vertexName).name()));
+ }
+
+ }
+
+ StringBuilder message = new StringBuilder("DL4J activation and SameDiff activation not equal for ");
+
+ for(Pair<String, String> pair : messages)
+ message.append("Layer ").append(pair.getFirst()).append(" and SDVariable ").append(pair.getSecond())
+ .append(", ");
+
+ assertEquals(message.toString(), 0, messages.size());
+
+ if(sameDiff.getTrainingConfig() != null && labels != null) {
+
+ for(GraphVertex vertex : graph.getVertices()){
+ Stage.Loss.record(vertex);
+ }
+
+ List<String> labelNames = sameDiff.getTrainingConfig().getDataSetLabelMapping();
+ Map<String, INDArray> inputAndLabelMap = new HashMap<>(inputsMap);
+ Preconditions.checkArgument(labels.length == labelNames.size(),
+ "Didn't supply the right number of labels: expected %s, got %s", labelNames.size(), labels.length);
+
+ for (int i = 0; i < labels.length; i++) {
+ inputAndLabelMap.put(labelNames.get(i), labels[i]);
+ }
+
+ graph.setLabels(labels);
+ graph.computeGradientAndScore();
+ double score = graph.score() - graph.calcRegularizationScore(true);
+
+ Map<String, INDArray> sdLosses = sameDiff.batchOutput()
+ .inputs(inputAndLabelMap)
+ .output(sameDiff.getLossVariables().toArray(new String[0]))
+ .output();
+
+ double sdScore = 0;
+ for(INDArray scoreArr : sdLosses.values())
+ sdScore += scoreArr.sumNumber().doubleValue();
+
+ Set<String> lossFunctions = new HashSet<>();
+ for(String name : graph.getConfiguration().getNetworkOutputs()){
+ GraphVertex vertex = graph.getVertex(name);
+ if(vertex.hasLayer()){
+ ILossFunction lossFn = getLossFn(vertex.getLayer());
+ if(lossFn != null)
+ lossFunctions.add(lossFn.getClass().getSimpleName());
+ }
+ }
+
+ try {
+ assertEquals("Losses don't match for original network and SameDiff version, with loss functions " + lossFunctions,
+ sdScore, score, 1e-3);
+ } catch (AssertionError ae){
+ if(ae.getMessage().contains("Losses don't match") && !lossFunctions.isEmpty()){
+ failureLosses.addAll(lossFunctions);
+ }
+ throw ae;
+ }
+ }
+
+ if(PRINT_AFTER_EVERY) {
+ printResults();
+ }
+ }
+
+ private static final Set<Class<? extends Layer>> foundLayers = findLayers();
+ private static final Set<Class<? extends ILossFunction>> foundLosses = findLosses();
+ private static final Set<Class<? extends IDropout>> foundDropouts = findDropouts();
+ private static final Set<Class<? extends IActivation>> foundActivations = findActivations();
+ private static final Set<Class<? extends InputPreProcessor>> foundPreprocessors = findPreprocessors();
+ private static final Set<Class<? extends GraphVertex>> foundVertices = findVertices();
+
+ public static int printResults() {
+ int conversion = Stage.Conversion.check(foundLayers, foundLosses, foundDropouts, foundActivations, foundPreprocessors, foundVertices);
+ int output = Stage.Output.check(foundLayers, foundLosses, foundDropouts, foundActivations, foundPreprocessors, foundVertices);
+ int loss = Stage.Loss.check(foundLayers, foundLosses, foundDropouts, foundActivations, foundPreprocessors, foundVertices);
+
+ if(!(failurePointVertices.isEmpty() && failureLosses.isEmpty() && failurePointLayers.isEmpty())){
+ log.info(" --- ToSameDiff Failure Points --- ");
+ }
+
+ if(!failurePointLayers.isEmpty()){
+ log.info("Failure point layers: {}", failurePointLayers);
+ }
+
+ if(!failurePointVertices.isEmpty()){
+ log.info("Failure point vertices: {}", failurePointVertices);
+ }
+
+ if(!failureLosses.isEmpty()){
+ log.info("Failed losses: {}", failureLosses);
+ }
+
+ return conversion + output + loss;
+ }
+
+ @Override
+ public void testRunFinished(Result result) throws Exception {
+ int failCount = printResults();
+
+ if(FAIL_IF_MISSING){
+ assertEquals("There were missing ToSameDiff tests", 0, failCount);
+ } else if(failCount > 0){
+ log.warn("There were {} missing ToSameDiff tests", failCount);
+ }
+ }
+}
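Since ToSameDiffTests doubles as a JUnit RunListener, the missing-conversion report can also be produced outside of a full build by attaching the listener to a runner by hand. A minimal sketch, assuming JUnit 4's JUnitCore and a hypothetical test class to run:

    import org.junit.runner.JUnitCore;

    public class RunToSameDiffSuite {
        public static void main(String[] args) {
            JUnitCore core = new JUnitCore();
            // testRunFinished() fires after the run and prints the missing-conversion report
            core.addListener(new ToSameDiffTests());
            core.run(SomeToSameDiffTestClass.class); // hypothetical test class
        }
    }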
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java
index 23c46835e8b7..3fadb25d9b23 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java
@@ -28,7 +28,6 @@
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.lossfunctions.SameDiffLoss;
import java.io.File;
import java.io.InputStream;
@@ -46,10 +45,10 @@ public class KerasCustomLossTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
- public class LogCosh extends SameDiffLoss {
+ public class LogCosh extends SameDiffNonFusedLoss {
@Override
- public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
- return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
+ public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) {
+ return sameDiff.math.log(sameDiff.math.cosh(labels.sub(layerInput)));
}
}
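For comparison, here is a sketch of another custom loss written against the same non-fused API. It is illustrative only: the class name and the MSLE-style formula are assumptions, not code from this patch, and it presumes SameDiffNonFusedLoss only requires the per-element loss array, as LogCosh above suggests:

    public class SquaredLogError extends SameDiffNonFusedLoss {
        @Override
        public SDVariable defineLossArray(SameDiff sameDiff, SDVariable layerInput, SDVariable labels) {
            // per-element squared difference of log(1 + x), an MSLE-style loss
            SDVariable logPredicted = sameDiff.math.log(layerInput.add(1.0));
            SDVariable logLabels = sameDiff.math.log(labels.add(1.0));
            return sameDiff.math.square(logLabels.sub(logPredicted));
        }
    }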
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java
index 4c860f7c8c5d..af112c2242c1 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/IOutputLayer.java
@@ -20,6 +20,7 @@
import org.deeplearning4j.nn.api.Layer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
/**
* Interface for output layers (those that calculate gradients with respect to a labels array)
@@ -64,6 +65,4 @@ public interface IOutputLayer extends Layer, Classifier {
* @return A column INDArray of shape [numExamples,1], where entry i is the score of the ith example
*/
INDArray computeScoreForExamples(double fullNetworkRegScore, LayerWorkspaceMgr workspaceMgr);
-
-
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
index 92b98eff5136..1f701f442df8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
@@ -17,8 +17,13 @@
package org.deeplearning4j.nn.conf;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -31,6 +36,9 @@
* for pre processing input before passing it
* to the neural network.
*
+ * You will most likely want to extend BaseInputPreProcessor when creating a custom preprocessor,
+ * as it supplies default exception-throwing define* methods.
+ *
* @author Adam Gibson
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@@ -69,4 +77,15 @@ public interface InputPreProcessor extends Serializable, Cloneable {
 Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize);
+
+ /**
+ * Define the InputPreProcessor's input transformation in a {@link SameDiff} instance.
+ * If this isn't supported, this method should throw an {@link UnsupportedOperationException}
+ * like the default implementation in {@link BaseInputPreProcessor}.
+ *
+ * @param sameDiff The {@link SameDiff} instance.
+ * @param input The input to transform.
+ * @return The transformed input.
+ */
+ @NonNull SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input);
}
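To illustrate the new hook, a custom preprocessor could implement definePreProcess as sketched below. This is an assumption-laden example, not code from this patch: the class and its fields are invented, and the reshape mirrors what a CNN-to-feed-forward preprocessor does to INDArrays:

    public class FlattenPreProcessor extends BaseInputPreProcessor {
        private final long channels, height, width; // assumed configuration fields

        public FlattenPreProcessor(long channels, long height, long width) {
            this.channels = channels;
            this.height = height;
            this.width = width;
        }

        @Override
        public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
            // flatten [minibatch, channels, height, width] into [minibatch, channels * height * width]
            return sameDiff.reshape(input, -1, channels * height * width);
        }

        // preProcess/backprop/getOutputType omitted for brevity; a real preprocessor implements them too
    }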
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
index 1a61acc4dbd2..0527f116314f 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
@@ -90,6 +90,9 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
//Counter for the number of epochs completed so far. Used for per-epoch schedules
protected int epochCount = 0;
+ @Getter
+ protected InputType inputType;
+
public int getEpochCount() {
return epochCount;
}
@@ -715,6 +718,7 @@ public MultiLayerConfiguration build() {
conf.inferenceWorkspaceMode = inferenceWorkspaceMode;
conf.cacheMode = cacheMode;
conf.dataType = dataType;
+ conf.inputType = inputType;
Nd4j.getRandom().setSeed(conf.getConf(0).getSeed());
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
index 462bc9f17c90..b4ad88ed4428 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
@@ -101,6 +101,8 @@ public class NeuralNetConfiguration implements Serializable, Cloneable {
//Counter for the number of epochs completed so far. Used for per-epoch schedules
protected int epochCount = 0;
+// protected IUpdater iUpdater;
+
/**
* Creates and returns a deep copy of the configuration.
@@ -1094,6 +1096,7 @@ public NeuralNetConfiguration build() {
conf.miniBatch = miniBatch;
conf.cacheMode = this.cacheMode;
conf.dataType = this.dataType;
+// conf.iUpdater = iUpdater;
configureLayer(layer);
if (layer instanceof FrozenLayer) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java
index aa2ed34781ab..a1c6736718dd 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/AlphaDropout.java
@@ -55,7 +55,7 @@
@EqualsAndHashCode(exclude = {"lastPValue","alphaPrime","a","b", "mask"})
@ToString(exclude = {"lastPValue","alphaPrime","a","b"})
@JsonIgnoreProperties({"lastPValue", "alphaPrime", "a", "b", "mask"})
-public class AlphaDropout implements IDropout {
+public class AlphaDropout extends BaseDropout {
public static final double DEFAULT_ALPHA = 1.6732632423543772;
public static final double DEFAULT_LAMBDA = 1.0507009873554804;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/BaseDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/BaseDropout.java
new file mode 100644
index 000000000000..10298b977ab0
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/BaseDropout.java
@@ -0,0 +1,35 @@
+/*
+ * ******************************************************************************
+ * * Copyright (c) 2020 Konduit K.K.
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.nn.conf.dropout;
+
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+
+public abstract class BaseDropout implements IDropout {
+ @Override
+ public SDVariable defineDropout(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
+ throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
+
+ @Override
+ public IDropout clone() {
+ throw new UnsupportedOperationException("Clone not implemented for " + this.getClass().getSimpleName());
+ }
+}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
index acb6afa2c8ee..7e8c3e367972 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
@@ -19,10 +19,13 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
+import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -69,7 +72,7 @@
@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
@Slf4j
-public class Dropout implements IDropout {
+public class Dropout extends BaseDropout {
/**
* When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed?
@@ -242,6 +245,14 @@ public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iterat
return gradAtInput;
}
+ @Override
+ public SDVariable defineDropout(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
+ if(pSchedule != null)
+ throw new UnsupportedOperationException("Scheduled dropout is not supported for SameDiff conversion");
+
+ return sameDiff.nn.dropout(input, p);
+ }
+
@Override
public void clear() {
mask = null;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java
index cd25718bc77a..7c3015cc9205 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianDropout.java
@@ -49,7 +49,7 @@
@Data
@JsonIgnoreProperties({"noise"})
@EqualsAndHashCode(exclude = {"noise"})
-public class GaussianDropout implements IDropout {
+public class GaussianDropout extends BaseDropout {
private final double rate;
private final ISchedule rateSchedule;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java
index d165614abca7..319c720cbd57 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/GaussianNoise.java
@@ -33,7 +33,7 @@
* @author Alex Black
*/
@Data
-public class GaussianNoise implements IDropout {
+public class GaussianNoise extends BaseDropout {
private double stddev;
private ISchedule stddevSchedule;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java
index 43ba15898a0d..2e3d7a9099aa 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/IDropout.java
@@ -16,7 +16,10 @@
package org.deeplearning4j.nn.conf.dropout;
+import lombok.NonNull;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -59,4 +62,15 @@ public interface IDropout extends Serializable, Cloneable {
void clear();
IDropout clone();
+
+ /**
+ * Define the dropout for a {@link SameDiff} instance.
+ * If this isn't supported, this method should throw an {@link UnsupportedOperationException}
+ * like the default implementation in {@link BaseDropout}.
+ *
+ * @param sameDiff The {@link SameDiff} instance
+ * @param input The input to the dropout, typically the output of the previous layer.
+ * @return The input variable with dropout applied.
+ */
+ @NonNull SDVariable defineDropout(@NonNull SameDiff sameDiff, @NonNull SDVariable input);
}
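As a usage sketch (assumed converter-side code, not part of this patch), the conversion machinery would typically route each layer's output through its configured dropout, if any:

    // hypothetical helper inside a toSameDiff conversion routine
    static SDVariable applyDropout(SameDiff sameDiff, SDVariable layerOutput, IDropout dropout) {
        if (dropout == null) {
            return layerOutput; // no dropout configured for this layer
        }
        return dropout.defineDropout(sameDiff, layerOutput);
    }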
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java
index 4d83e7fe8701..e9fd147c9e40 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/SpatialDropout.java
@@ -46,7 +46,7 @@
@Data
@JsonIgnoreProperties({"mask"})
@EqualsAndHashCode(exclude = {"mask"})
-public class SpatialDropout implements IDropout {
+public class SpatialDropout extends BaseDropout {
private double p;
private ISchedule pSchedule;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java
index 65515153d15d..f2d0721c9a62 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/StackVertex.java
@@ -18,6 +18,7 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeRecurrent;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
@@ -79,6 +80,15 @@ public String toString() {
return "StackVertex()";
}
+ private boolean compatibleInputTypes(InputType original, InputType other){
+ if(original instanceof InputTypeRecurrent && other instanceof InputTypeRecurrent){
+ return ((InputTypeRecurrent) original).getFormat().equals(((InputTypeRecurrent) other).getFormat()) &&
+ ((InputTypeRecurrent) original).getSize() == ((InputTypeRecurrent) other).getSize();
+ } else {
+ return original.equals(other);
+ }
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
if (vertexInputs.length == 1)
@@ -87,11 +97,12 @@ public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws
//Check that types are all the same...
 for( int i=1; i<vertexInputs.length; i++){
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ return activationFn.defineActivation(sameDiff, layerInput);
+ }
+
@Override
public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
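The defineLayer body above delegates entirely to IActivation.defineActivation(SameDiff, SDVariable). As an illustrative sketch (invented class, not from this patch), a custom activation plugs into the same hook:

    public class ScaledTanh extends BaseActivationFunction {
        @Override
        public SDVariable defineActivation(SameDiff sameDiff, SDVariable input) {
            // 1.7159 * tanh((2/3) * x), the LeCun-scaled tanh
            return sameDiff.math.tanh(input.mul(2.0 / 3.0)).mul(1.7159);
        }

        // getActivation()/backprop() INDArray implementations omitted for brevity
    }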
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java
index 7abe0da061f6..57395d7bb3ca 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java
@@ -25,6 +25,8 @@
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.util.NetworkUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater;
@@ -145,6 +147,16 @@ public List getRegularizationByParam(String paramName){
return null;
}
+ /**
+ * Applies the activation function if it isn't null.
+ */
+ protected SDVariable doActivation(@NonNull SDVariable input){
+ if(activationFn != null)
+ return activationFn.defineActivation(input.getSameDiff(), input);
+ else
+ return input;
+ }
+
@SuppressWarnings("unchecked")
@Getter
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java
index e8de58d9fa7f..8774f08284ed 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseOutputLayer.java
@@ -20,18 +20,17 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
-import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
-import org.nd4j.linalg.lossfunctions.impl.LossMSE;
-import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-public abstract class BaseOutputLayer extends FeedForwardLayer {
+public abstract class BaseOutputLayer extends FeedForwardLayer implements LayerWithLoss {
protected ILossFunction lossFn;
protected boolean hasBias = true;
@@ -79,6 +78,11 @@ public LayerMemoryReport getMemoryReport(InputType inputType) {
.build();
}
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ throw new UnsupportedOperationException("SameDiff loss conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
@Getter
@Setter
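The stub above means concrete output layers only need to override defineLoss when they actually support conversion. A sketch of such an override, using the same ILossFunction delegation that CnnLossLayer and Cnn3DLossLayer use later in this patch:

    @Override
    public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
            boolean average) {
        // Delegate to the configured loss function; lossFn is inherited from BaseOutputLayer.
        return lossFn.defineLoss(sameDiff, input, labels, average);
    }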
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java
index 0b98dfad9d32..e10b28cf3ee5 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java
@@ -16,18 +16,22 @@
package org.deeplearning4j.nn.conf.layers;
+import java.util.Map;
import lombok.*;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional.Mode;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import java.util.Arrays;
import java.util.List;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
@Data
@NoArgsConstructor
@@ -44,6 +48,27 @@ protected BaseRecurrentLayer(Builder builder) {
this.rnnDataFormat = builder.rnnDataFormat;
}
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ @NonNull Map<String, SDVariable> paramTable, SDVariable mask, boolean backwards){
+ throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ return defineLayer(sameDiff, layerInput, paramTable, mask, false);
+ }
+
+ /**
+ * An optional method that, if implemented, defines the bidirectional operation as a single pass.
+ * If not implemented, it should throw an {@link UnsupportedOperationException}, in which case the forward and backward
+ * passes are run separately and combined.
+ */
+ public SDVariable defineBidirectional(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ @NonNull Map<String, SDVariable> paramTable, SDVariable mask, Mode mode) {
+ throw new UnsupportedOperationException("Bidirectional toSameDiff not supported for " + this.getClass().getSimpleName());
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java
index dcced3aeb2bc..3809421095a3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java
@@ -30,8 +30,11 @@
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.regularization.Regularization;
@@ -108,6 +111,55 @@ public ParamInitializer initializer() {
return BatchNormalizationParamInitializer.getInstance();
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params) {
+ if(lockGammaBeta)
+ throw new UnsupportedOperationException("Locked Gamma & Beta not supported for SameDiff conversion");
+ if(useLogStd)
+ throw new UnsupportedOperationException("LogStd not supported for SameDiff conversion");
+
+ INDArray beta = params.get(BatchNormalizationParamInitializer.BETA);
+ INDArray gamma = params.get(BatchNormalizationParamInitializer.GAMMA);
+ INDArray mean = params.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
+ INDArray variance = params.get(BatchNormalizationParamInitializer.GLOBAL_VAR);
+
+ params.put(BatchNormalizationParamInitializer.BETA, Nd4j.squeeze(beta, 0));
+ params.put(BatchNormalizationParamInitializer.GAMMA, Nd4j.squeeze(gamma, 0));
+ params.put(BatchNormalizationParamInitializer.GLOBAL_MEAN, Nd4j.squeeze(mean, 0));
+ params.put(BatchNormalizationParamInitializer.GLOBAL_VAR, Nd4j.squeeze(variance, 0));
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ if(lockGammaBeta)
+ throw new UnsupportedOperationException("Locked Gamma & Beta not supported for SameDiff conversion");
+ if(useLogStd)
+ throw new UnsupportedOperationException("LogStd not supported for SameDiff conversion");
+
+ SDVariable beta = paramTable.get(BatchNormalizationParamInitializer.BETA);
+ SDVariable gamma = paramTable.get(BatchNormalizationParamInitializer.GAMMA);
+ SDVariable mean = paramTable.get(BatchNormalizationParamInitializer.GLOBAL_MEAN);
+ SDVariable variance = paramTable.get(BatchNormalizationParamInitializer.GLOBAL_VAR);
+
+ int axis;
+ if(cnn2DFormat == CNN2DFormat.NCHW)
+ axis = 1;
+ else if(cnn2DFormat == CNN2DFormat.NHWC)
+ axis = 3;
+ else
+ throw new UnsupportedOperationException("Unknown CNN data format " + cnn2DFormat);
+
+ SDVariable output = sameDiff.nn.batchNorm(layerInput,
+ mean,
+ variance,
+ gamma,
+ beta,
+ eps,
+ axis);
+ return doActivation(output);
+ }
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
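The squeezes in transformParamsForSameDiff are there because DL4J stores the batch-norm vectors as [1, nOut] row matrices, while sameDiff.nn.batchNorm expects rank-1 vectors of length nOut. A standalone sketch of the shape change (nOut = 32 assumed):

    INDArray gamma = Nd4j.ones(1, 32);          // DL4J layout: [1, nOut]
    INDArray gammaVec = Nd4j.squeeze(gamma, 0); // SameDiff layout: [nOut]
    System.out.println(java.util.Arrays.toString(gammaVec.shape())); // prints [32]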
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java
index 5769f2a83e24..6156e11b58a8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java
@@ -99,7 +99,7 @@ public void setNIn(InputType inputType, boolean override) {
}
@Override
- public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sd, SDVariable input, SDVariable mask, Map<String, SDVariable> paramTable) {
// input: [mb, inputCapsules, inputCapsuleDimensions]
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java
index 34213038a918..13028b263f98 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java
@@ -24,7 +24,10 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.CenterLossParamInitializer;
+import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java
index 1bde3d912e47..d7f588a8d808 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java
@@ -22,13 +22,17 @@
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import java.util.Collection;
@@ -59,7 +63,7 @@
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-public class Cnn3DLossLayer extends FeedForwardLayer {
+public class Cnn3DLossLayer extends FeedForwardLayer implements LayerWithLoss {
protected ILossFunction lossFn;
protected Convolution3D.DataFormat dataFormat;
@@ -89,6 +93,49 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ SDVariable batch = sameDiff.sizeAt(layerInput, 0);
+ SDVariable channels;
+ SDVariable depth;
+ SDVariable height;
+ SDVariable width;
+
+ if(dataFormat == DataFormat.NCDHW){
+ channels = sameDiff.sizeAt(layerInput, 1);
+ depth = sameDiff.sizeAt(layerInput, 2);
+ height = sameDiff.sizeAt(layerInput, 3);
+ width = sameDiff.sizeAt(layerInput, 4);
+ layerInput = layerInput.permute(0, 2, 3, 4, 1);
+ } else if(dataFormat == DataFormat.NDHWC){
+ depth = sameDiff.sizeAt(layerInput, 1);
+ height = sameDiff.sizeAt(layerInput, 2);
+ width = sameDiff.sizeAt(layerInput, 3);
+ channels = sameDiff.sizeAt(layerInput, 4);
+ } else
+ throw new UnsupportedOperationException("Unknown CNN 3D data format " + dataFormat);
+
+ SDVariable distributedInput = layerInput.reshape(sameDiff.concat(0, sameDiff.constant(
+ Nd4j.scalar(batch.dataType(), -1)), channels));
+
+ SDVariable distributedOutput = doActivation(distributedInput);
+
+ SDVariable output = distributedOutput.reshape(sameDiff.concat(0, batch, depth, height, width, channels));
+
+ if(dataFormat == DataFormat.NCDHW)
+ return output.permute(0, 4, 1, 2, 3);
+ else
+ return output;
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ return lossFn.defineLoss(sameDiff, input, labels, average);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || (inputType.getType() != InputType.Type.CNN3D
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java
index 647b187e38ec..617b05e6221f 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -30,9 +31,12 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@@ -60,7 +64,7 @@
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-public class CnnLossLayer extends FeedForwardLayer {
+public class CnnLossLayer extends FeedForwardLayer implements LayerWithLoss {
protected ILossFunction lossFn;
protected CNN2DFormat format = CNN2DFormat.NCHW;
@@ -90,6 +94,46 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ SDVariable batch = sameDiff.sizeAt(layerInput, 0);
+ SDVariable channels;
+ SDVariable height;
+ SDVariable width;
+
+ if(format == CNN2DFormat.NCHW){
+ channels = sameDiff.sizeAt(layerInput, 1);
+ height = sameDiff.sizeAt(layerInput, 2);
+ width = sameDiff.sizeAt(layerInput, 3);
+ layerInput = layerInput.permute(0, 2, 3, 1);
+ } else if(format == CNN2DFormat.NHWC){
+ height = sameDiff.sizeAt(layerInput, 1);
+ width = sameDiff.sizeAt(layerInput, 2);
+ channels = sameDiff.sizeAt(layerInput, 3);
+ } else
+ throw new UnsupportedOperationException("Unknown CNN data format " + format);
+
+ SDVariable distributedInput = layerInput.reshape(sameDiff.concat(0, sameDiff.constant(
+ Nd4j.scalar(batch.dataType(), -1)), channels));
+
+ SDVariable distributedOutput = doActivation(distributedInput);
+
+ SDVariable output = distributedOutput.reshape(sameDiff.concat(0, batch, height, width, channels));
+
+ if(format == CNN2DFormat.NCHW)
+ return output.permute(0, 3, 1, 2);
+ else
+ return output;
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ return lossFn.defineLoss(sameDiff, input, labels, average);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || (inputType.getType() != InputType.Type.CNN
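The reshape trick above (shared with Cnn3DLossLayer) flattens the activations so the activation function is applied identically at every spatial position: an NHWC tensor [mb, h, w, c] becomes [mb*h*w, c] via the -1 wildcard, then the original shape is restored; the NCHW branch just permutes to NHWC first and back at the end. A small shape walk-through, with sizes assumed:

    INDArray in = Nd4j.rand(2, 3, 3, 4);  // [mb=2, h=3, w=3, c=4]
    INDArray flat = in.reshape(-1, 4);    // [mb*h*w, c] == [18, 4]
    System.out.println(java.util.Arrays.toString(flat.shape()));                 // [18, 4]
    System.out.println(java.util.Arrays.toString(flat.reshape(2, 3, 3, 4).shape())); // [2, 3, 3, 4]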
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java
index f4d247670f53..8c38266a9d41 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java
@@ -19,20 +19,28 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
+import org.nd4j.linalg.factory.Nd4j;
/**
* 1D (temporal) convolutional layer. This layer accepts RNN InputTypes instead of CNN InputTypes
@@ -77,6 +85,41 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params) {
+ INDArray weight = params.get(ConvolutionParamInitializer.WEIGHT_KEY);
+ params.put(ConvolutionParamInitializer.WEIGHT_KEY, Nd4j.squeeze(weight, 3).permute(2, 1, 0));
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable weight = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
+
+ PaddingMode paddingMode;
+
+ if(convolutionMode == ConvolutionMode.Same)
+ paddingMode = PaddingMode.SAME;
+ else if(convolutionMode == ConvolutionMode.Causal)
+ paddingMode = PaddingMode.CAUSAL;
+ else
+ paddingMode = PaddingMode.VALID;
+
+ SDVariable value = sameDiff.cnn.conv1d(layerInput, weight, bias,
+ Conv1DConfig.builder()
+ .dataFormat(rnnDataFormat.name())
+ .paddingMode(paddingMode)
+ .k(kernelSize[0])
+ .s(stride[0])
+ .p(padding[0])
+ .d(dilation[0])
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
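Reading transformParamsForSameDiff together with the Conv1DConfig above: judging from the squeeze along dimension 3, DL4J keeps conv1d weights in the 2D-convolution layout [nOut, nIn, k, 1], so dropping the dummy width and permuting yields the [k, nIn, nOut] layout consumed by sameDiff.cnn.conv1d. A sketch of the relayout, with sizes assumed:

    INDArray w = Nd4j.rand(16, 8, 5, 1);                // [nOut=16, nIn=8, k=5, 1]
    INDArray wSd = Nd4j.squeeze(w, 3).permute(2, 1, 0); // -> [k=5, nIn=8, nOut=16]
    System.out.println(java.util.Arrays.toString(wSd.shape())); // [5, 8, 16]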
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java
index dc88116e5de5..3d70dbc4e338 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf.layers;
+import java.util.HashMap;
import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -29,11 +30,14 @@
import org.deeplearning4j.util.Convolution3DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
/**
* 3D convolution layer configuration
@@ -113,6 +117,33 @@ public ParamInitializer initializer() {
return Convolution3DParamInitializer.getInstance();
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params) {
+ INDArray weight = params.get(Convolution3DParamInitializer.WEIGHT_KEY);
+ params.put(Convolution3DParamInitializer.WEIGHT_KEY, weight.permute(2, 3, 4, 1, 0));
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable weight = paramTable.get(Convolution3DParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(Convolution3DParamInitializer.BIAS_KEY);
+
+ SDVariable value = sameDiff.cnn.conv3d(layerInput, weight, bias,
+ Conv3DConfig.builder()
+ .dataFormat(this.dataFormat.name())
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .kD(kernelSize[0]).kH(kernelSize[1]).kW(kernelSize[2])
+ .sD(stride[0]).sH(stride[1]).sW(stride[2])
+ .pD(padding[0]).pH(padding[1]).pW(padding[2])
+ .dD(dilation[0]).dH(dilation[1]).dW(dilation[2])
+ .biasUsed(hasBias)
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java
index 3b2d4c0befe7..c7cad4e0e07a 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java
@@ -16,10 +16,25 @@
package org.deeplearning4j.nn.conf.layers;
-import lombok.*;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.Setter;
+import lombok.ToString;
+import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
-import org.deeplearning4j.nn.conf.*;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.CacheMode;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
@@ -27,14 +42,13 @@
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
+import org.nd4j.enums.WeightsFormat;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
/**
* 2D Convolution layer (for example, spatial convolution over images). Input activations should be format {@code
@@ -184,6 +198,27 @@ public ParamInitializer initializer() {
return ConvolutionParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable weight = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
+
+ SDVariable value = sameDiff.cnn.conv2d(layerInput, weight, bias,
+ Conv2DConfig.builder()
+ .dataFormat(this.cnn2dDataFormat.name())
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .kH(kernelSize[0]).kW(kernelSize[1])
+ .sH(stride[0]).sW(stride[1])
+ .pH(padding[0]).pW(padding[1])
+ .dH(dilation[0]).dW(dilation[1])
+ .weightsFormat(WeightsFormat.OIYX)
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java
index b2f64c89450b..30a07c4ee5d6 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -30,11 +31,14 @@
import org.deeplearning4j.nn.params.DeconvolutionParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
/**
* 2D deconvolution layer configuration
@@ -108,6 +112,32 @@ public ParamInitializer initializer() {
return DeconvolutionParamInitializer.getInstance();
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params) {
+ INDArray weight = params.get(DeconvolutionParamInitializer.WEIGHT_KEY);
+ params.put(DeconvolutionParamInitializer.WEIGHT_KEY, weight.permute(2, 3, 1, 0));
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable weight = paramTable.get(DeconvolutionParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(DeconvolutionParamInitializer.BIAS_KEY);
+
+ SDVariable value = sameDiff.cnn.deconv2d(layerInput, weight, bias,
+ DeConv2DConfig.builder()
+ .dataFormat(this.cnn2dDataFormat.name())
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .kH(kernelSize[0]).kW(kernelSize[1])
+ .sH(stride[0]).sW(stride[1])
+ .pH(padding[0]).pW(padding[1])
+ .dH(dilation[0]).dW(dilation[1])
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java
index 01bd3ca832c0..73111356644f 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -26,17 +27,19 @@
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer;
import org.deeplearning4j.nn.layers.convolution.Deconvolution3DLayer;
+import org.deeplearning4j.nn.params.Convolution3DParamInitializer;
import org.deeplearning4j.nn.params.Deconvolution3DParamInitializer;
-import org.deeplearning4j.nn.params.DeconvolutionParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
/**
* 3D deconvolution layer configuration
@@ -110,6 +113,26 @@ public ParamInitializer initializer() {
return Deconvolution3DParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable weight = paramTable.get(Convolution3DParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(Convolution3DParamInitializer.BIAS_KEY);
+
+ SDVariable value = sameDiff.cnn.deconv3d(layerInput, weight, bias,
+ DeConv3DConfig.builder()
+ .dataFormat(this.dataFormat.name())
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .kD(kernelSize[0]).kH(kernelSize[1]).kW(kernelSize[2])
+ .sD(stride[0]).sH(stride[1]).sW(stride[2])
+ .pD(padding[0]).pH(padding[1]).pW(padding[2])
+ .dD(dilation[0]).dH(dilation[1]).dW(dilation[2])
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java
index 67cac076d11b..b5d3061fef5d 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java
@@ -25,6 +25,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -72,6 +74,27 @@ public ParamInitializer initializer() {
return DefaultParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ SDVariable weight = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
+ // may be null
+ SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY);
+
+ SDVariable temp = layerInput.mmul(weight);
+
+ if(hasLayerNorm()){
+ SDVariable gain = paramTable.get(DefaultParamInitializer.GAIN_KEY);
+ temp = sameDiff.nn.layerNorm(temp, gain, bias, false, 1);
+ }
+
+ if(hasBias())
+ temp = temp.add(bias);
+
+ return doActivation(temp);
+ }
+
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
InputType outputType = getOutputType(-1, inputType);
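To make the new hook concrete, a sketch of driving defineLayer by hand for a dense layer. The parameter keys come from DefaultParamInitializer; the placeholder shape, parameter values, and variable names are assumptions (imports assumed: SameDiff, SDVariable, DataType, Nd4j, Activation, DefaultParamInitializer, HashMap):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);

    Map<String, SDVariable> params = new HashMap<>();
    params.put(DefaultParamInitializer.WEIGHT_KEY, sd.var("W", Nd4j.rand(DataType.FLOAT, 10, 4)));
    params.put(DefaultParamInitializer.BIAS_KEY, sd.var("b", Nd4j.zeros(DataType.FLOAT, 1, 4)));

    DenseLayer layerConf = new DenseLayer.Builder().nIn(10).nOut(4)
            .activation(Activation.RELU).build();
    SDVariable out = layerConf.defineLayer(sd, in, null, params); // mask unused here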
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java
index 7edf65618dd1..5fb77cd298e0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java
@@ -20,6 +20,7 @@
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.DepthwiseConvolution2DLayer;
@@ -27,11 +28,15 @@
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
+import org.nd4j.enums.WeightsFormat;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.*;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
/**
* 2D depth-wise convolution layer configuration.
@@ -89,6 +94,31 @@ public ParamInitializer initializer() {
return DepthwiseConvolutionParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable weight = paramTable.get(DepthwiseConvolutionParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(DepthwiseConvolutionParamInitializer.BIAS_KEY);
+
+ if(depthMultiplier != 1)
+ throw new UnsupportedOperationException("Can't convert depthwise convolutions wih a depth multiplier != 1");
+
+ //TODO can't set depthMultiplier?
+ SDVariable value = sameDiff.cnn.depthWiseConv2d(layerInput, weight, bias,
+ Conv2DConfig.builder()
+ .dataFormat(this.cnn2dDataFormat.name())
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .kH(kernelSize[0]).kW(kernelSize[1])
+ .sH(stride[0]).sW(stride[1])
+ .pH(padding[0]).pW(padding[1])
+ .dH(dilation[0]).dW(dilation[1])
+ .weightsFormat(WeightsFormat.OIYX)
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java
index 94cad0a9804f..e3310db142e0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java
@@ -27,6 +27,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.regularization.Regularization;
@@ -84,6 +86,12 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ return doActivation(layerInput);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java
index 6478b6d59b67..4b1bb8550970 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java
@@ -23,13 +23,14 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -83,6 +84,25 @@ public ParamInitializer initializer() {
return EmbeddingLayerParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+// SDVariable weight = paramTable.get(EmbeddingLayerParamInitializer.WEIGHT_KEY);
+// SDVariable bias = paramTable.get(EmbeddingLayerParamInitializer.BIAS_KEY);
+//
+// TODO this cast causes a JVM crash
+// SDVariable indices = sameDiff.squeeze(layerInput, 1).castTo(DataType.INT64);
+//
+// System.out.println("Here!");
+// SDVariable out = sameDiff.gather(weight, indices, 1);
+//
+// if(hasBias)
+// out = out.add(bias);
+//
+// return doActivation(out);
+ throw new UnsupportedOperationException("Can't convert EmbeddingLayer to SameDiff");
+ }
+
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
//Basically a dense layer, but no dropout is possible here, and no epsilons
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java
index 1b76b6c7b7e4..d7e6b99c9431 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java
@@ -31,6 +31,8 @@
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java
index d9e10e6f5161..53b1c0caa121 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java
@@ -28,6 +28,8 @@
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -106,6 +108,7 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java
index ba0b52d8c555..7a01b4b9c223 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java
@@ -16,29 +16,52 @@
package org.deeplearning4j.nn.conf.layers;
-import lombok.*;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional.Mode;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationELU;
+import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationLReLU;
+import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
+import org.nd4j.linalg.activations.impl.ActivationSoftSign;
+import org.nd4j.linalg.activations.impl.ActivationTanH;
+import org.nd4j.linalg.activations.impl.ActivationThresholdedReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
/**
* LSTM recurrent neural network layer without peephole connections. Supports CuDNN acceleration - see https://deeplearning4j.konduit.ai/config/backends/config-cudnn for details
+ * href="https://deeplearning4j.konduit.ai/config/backends/config-cudnn">https://deeplearning4j.konduit.ai/config/backends/config-cudnn
+ * for details
*
* @author Alex Black
* @see GravesLSTM GravesLSTM class for an alternative LSTM (with peephole connections)
@@ -76,9 +99,10 @@ protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Bui
@Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
- int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
+ int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerValidation.assertNInNOutSet("LSTM", getLayerName(), layerIndex, getNIn(), getNOut());
- org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf, networkDataType);
+ org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf,
+ networkDataType);
ret.setListeners(trainingListeners);
ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView);
@@ -93,6 +117,141 @@ public ParamInitializer initializer() {
return LSTMParamInitializer.getInstance();
}
+ private static LSTMActivations toLSTMActivation(IActivation activationFn){
+ if(activationFn instanceof ActivationTanH)
+ return LSTMActivations.TANH;
+ else if(activationFn instanceof ActivationReLU) {
+ ActivationReLU relu = (ActivationReLU) activationFn;
+ if(relu.getThreshold() != 0 || relu.getNegativeSlope() != 0)
+ throw new UnsupportedOperationException("LSTM toSameDiff doesn't support ReLU activation with a nonzero threshold or negative slope.");
+
+ if(relu.getMax() != 0)
+ throw new UnsupportedOperationException("LSTM toSameDiff doesn't support ReLU activation with max.");
+
+ //TODO no way to pass params to libnd4j
+// if(relu.getNegativeSlope() != 0)
+// return LSTMActivations.LEAKY_RELU;
+//
+// if(relu.getThreshold() != 0)
+// return LSTMActivations.THRESHHOLD_RELU;
+
+ return LSTMActivations.RELU;
+ } else if(activationFn instanceof ActivationSigmoid)
+ return LSTMActivations.SIGMOID;
+ else if(activationFn instanceof ActivationLReLU)
+// return LSTMActivations.LEAKY_RELU;
+ //TODO no way to pass params to libnd4j
+ throw new UnsupportedOperationException("LSTM toSameDiff doesn't support activation ActivationLReLU");
+ else if(activationFn instanceof ActivationThresholdedReLU)
+// return LSTMActivations.THRESHHOLD_RELU;
+ //TODO no way to pass params to libnd4j
+ throw new UnsupportedOperationException("LSTM toSameDiff doesn't support activation ActivationThresholdedReLU");
+ else if(activationFn instanceof ActivationHardSigmoid)
+ return LSTMActivations.HARD_SIGMOID;
+ else if(activationFn instanceof ActivationELU)
+ return LSTMActivations.ELU;
+ else if(activationFn instanceof ActivationSoftSign)
+ return LSTMActivations.SOFTSIGN;
+ else if(activationFn instanceof ActivationSoftPlus)
+ return LSTMActivations.SOFTPLUS;
+ else
+ //TODO add ActivationThresholdedReLU and ActivationLReLU to list once supported
+ throw new UnsupportedOperationException("Unsupported activation for LSTM toSameDiff: " + activationFn.getClass().getSimpleName() +
+ ". Should be one of ActivationTanH, ActivationReLU, ActivationSigmoid, "
+ + "ActivationHardSigmoid, ActivationELU, ActivationSoftSign, or ActivationSoftPlus.");
+ }
+
+ /**
+ * Change weight from [input, forget, output, gate] (DL4J) to [input, forget, gate, output] (SameDiff)
+ */
+ private static INDArray changeWeight(INDArray weight){
+ int size = (int) (weight.size(1) / 4);
+ INDArray input = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
+ INDArray forget = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2*size));
+ INDArray output = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(2*size, 3*size));
+ INDArray gate = weight.get(NDArrayIndex.all(), NDArrayIndex.interval(3*size, 4*size));
+ return Nd4j.concat(1, input, forget, gate, output);
+ }
+
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params) {
+ INDArray bias = params.get(LSTMParamInitializer.BIAS_KEY);
+ params.put(LSTMParamInitializer.BIAS_KEY, Nd4j.squeeze(bias, 0));
+
+ params.put(LSTMParamInitializer.INPUT_WEIGHT_KEY,
+ changeWeight(params.get(LSTMParamInitializer.INPUT_WEIGHT_KEY)));
+
+ params.put(LSTMParamInitializer.RECURRENT_WEIGHT_KEY,
+ changeWeight(params.get(LSTMParamInitializer.RECURRENT_WEIGHT_KEY)));
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ SDVariable recurrentWeight = paramTable.get(LSTMParamInitializer.RECURRENT_WEIGHT_KEY);
+ SDVariable inputWeight = paramTable.get(LSTMParamInitializer.INPUT_WEIGHT_KEY);
+ SDVariable bias = paramTable.get(LSTMParamInitializer.BIAS_KEY);
+
+ LSTMActivations gateActivation = toLSTMActivation(gateActivationFn);
+ LSTMActivations recurrentActivation = toLSTMActivation(activationFn);
+
+
+ return sameDiff.rnn.lstmLayer(layerInput, LSTMLayerWeights.builder()
+ .weights(inputWeight)
+ .rWeights(recurrentWeight)
+ .bias(bias)
+ .build(),
+ LSTMLayerConfig.builder()
+ .gateAct(gateActivation)
+ .cellAct(recurrentActivation)
+ .outAct(recurrentActivation)
+ .retFullSequence(true)
+ .directionMode(LSTMDirectionMode.FWD)
+ .lstmdataformat(rnnDataFormat == RNNFormat.NCW ? LSTMDataFormat.NST : LSTMDataFormat.NTS)
+ .build())[0];
+ }
+
+ @Override
+ public SDVariable defineBidirectional(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable,
+ SDVariable mask, Mode mode) {
+ //TODO need different param transforms for bidirectional
+// SDVariable recurrentWeight = paramTable.get(LSTMParamInitializer.RECURRENT_WEIGHT_KEY);
+// SDVariable inputWeight = paramTable.get(LSTMParamInitializer.INPUT_WEIGHT_KEY);
+// SDVariable bias = paramTable.get(LSTMParamInitializer.BIAS_KEY);
+//
+// LSTMActivations gateActivation = toLSTMActivation(gateActivationFn);
+// LSTMActivations recurrentActivation = toLSTMActivation(activationFn);
+//
+// LSTMDirectionMode directionMode;
+// if(mode == Mode.ADD || mode == Mode.AVERAGE)
+// directionMode = LSTMDirectionMode.BIDIR_SUM;
+// else if(mode == Mode.CONCAT)
+// directionMode = LSTMDirectionMode.BIDIR_CONCAT;
+// else
+// throw new UnsupportedOperationException("Bidirectional not supported for mode " + mode);
+//
+// LSTMDataFormat format = rnnDataFormat == RNNFormat.NCW ? LSTMDataFormat.NST : LSTMDataFormat.NTS;
+//
+// SDVariable output = sameDiff.rnn.lstmLayer(layerInput, LSTMLayerWeights.builder()
+// .weights(inputWeight)
+// .rWeights(recurrentWeight)
+// .bias(bias)
+// .build(),
+// LSTMLayerConfig.builder()
+// .gateAct(gateActivation)
+// .cellAct(recurrentActivation)
+// .directionMode(directionMode)
+// .lstmdataformat(format)
+// .build())[0];
+//
+// if(mode == Mode.AVERAGE)
+// return output.div(2);
+// else
+// return output;
+ return super.defineBidirectional(sameDiff, layerInput, paramTable, mask, mode);
+ }
+
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
//TODO - CuDNN etc
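A tiny worked example of what changeWeight does to the gate blocks: with nOut = 1 each gate occupies a single column, so the [i | f | o | g] -> [i | f | g | o] swap is easy to eyeball (values are arbitrary):

    INDArray w = Nd4j.create(new double[][]{{1, 2, 3, 4}}); // columns: i=1, f=2, o=3, g=4
    int size = (int) (w.size(1) / 4);
    INDArray i = w.get(NDArrayIndex.all(), NDArrayIndex.interval(0, size));
    INDArray f = w.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));
    INDArray o = w.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * size, 3 * size));
    INDArray g = w.get(NDArrayIndex.all(), NDArrayIndex.interval(3 * size, 4 * size));
    System.out.println(Nd4j.concat(1, i, f, g, o)); // [[1.0, 2.0, 4.0, 3.0]]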
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java
index 25577bd1fad6..7f7c09030ed3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.TrainingConfig;
@@ -30,6 +31,8 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
@@ -97,6 +100,36 @@ protected void initializeConstraints(Builder<?> builder) {
this.iDropout = builder.iDropout;
}
+
+ /**
+ * Define the layer for SameDiff conversion.
+ * If conversion isn't supported, this method should throw an {@link UnsupportedOperationException}, as the default implementation does.
+ *
+ * @param sameDiff SameDiff instance
+ * @param layerInput Input to the layer
+ * @param mask Optional, may be null. Mask to apply if supported
+ * @param paramTable Parameter table - keys and shapes as defined in the layer implementation class.
+ * @return The final layer variable corresponding to the activations/output from the forward pass
+ */
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, SDVariable mask,
+ @NonNull Map<String, SDVariable> paramTable){
+ throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
+
+ /**
+ * Do any necessary transforms to parameters (weights, biases, etc) before making SDVariables out of them.
+ * Useful for things like changing the dimension order or squeezing.
+ *
+ * Adding or removing parameters is supported.
+ *
+ * Should throw an {@link UnsupportedOperationException} if conversion of this layer configuration isn't
+ * supported, as transforming the weights would otherwise cause an error.
+ *
+ * @param params The parameters of the layer.
+ */
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params){
+ }
+
/**
* Reset the learning related configs of the layer to default. When instantiated with a global
* neural network configuration the parameters specified in the neural network configuration
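Putting the two new hooks together, the conversion flow implied by the javadocs is: copy the layer's parameters, let the layer transform them, wrap the results as SDVariables, then call defineLayer. A sketch; the layer, netParams, and input variables here are hypothetical:

    Map<String, INDArray> params = new HashMap<>(netParams); // copy: transforms may replace entries
    layer.transformParamsForSameDiff(params);                // e.g. squeeze/permute, as in BatchNormalization

    SameDiff sd = SameDiff.create();
    Map<String, SDVariable> vars = new HashMap<>();
    for (Map.Entry<String, INDArray> e : params.entrySet()) {
        vars.put(e.getKey(), sd.var(e.getKey(), e.getValue()));
    }
    SDVariable out = layer.defineLayer(sd, input, null, vars); // "input" assumed created earlier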
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerWithLoss.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerWithLoss.java
new file mode 100644
index 000000000000..6259b69ec923
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerWithLoss.java
@@ -0,0 +1,42 @@
+/*
+ * ******************************************************************************
+ * * Copyright (c) 2020 Konduit K.K.
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ */
+
+package org.deeplearning4j.nn.conf.layers;
+
+import lombok.NonNull;
+import org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+
+/**
+ * Any loss or output layers that support SameDiff conversion must implement this.
+ */
+public interface LayerWithLoss {
+
+ /**
+ * Define the layer's loss function. Should return a scalar.
+ *
+ * If average is true, should be the batchwise average, otherwise the sum.
+ * @param sameDiff The {@link SameDiff} instance
+ * @param input The input to the loss function, the output (activations) of this layer.
+ * @param labels The labels to compare the output to. The placeholder will be created with the shape of the output (activations) of this layer. May be null if the implementation layer doesn't require labels (e.g. {@link OCNNOutputLayer}).
+ * @param average Whether to average the loss per example. Most of the time this should be passed to the {@link org.nd4j.linalg.lossfunctions.ILossFunction}.
+ * @return The loss scalar.
+ */
+ SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels, boolean average);
+}
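A sketch of how a caller would wire defineLoss into a graph, assuming outputLayer implements LayerWithLoss and activations is the SDVariable returned by its defineLayer (the labels shape is an assumption):

    SDVariable labels = sd.placeHolder("labels", activations.dataType(), -1, 4);
    SDVariable loss = outputLayer.defineLoss(sd, activations, labels, true); // batchwise average
    sd.setLossVariables(loss.name());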
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java
index 8ddd20b4558e..3d210675f1b8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java
@@ -24,7 +24,6 @@
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
-import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@@ -149,9 +148,10 @@ public void initializeParameters(Map<String, INDArray> params) {
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable) {
val baseQueries = paramTable.get(WEIGHT_QUERIES);
- val batchSize = layerInput.shape().get(SDIndex.point(0));
+ val batchSize = sameDiff.sizeAt(layerInput, 0);
val tileAxis = sameDiff.scatterUpdate(sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
val queries = sameDiff.tile(baseQueries, tileAxis);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java
index ebfc56a7b1ae..5bf17f72776d 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java
@@ -27,9 +27,12 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.Collection;
@@ -90,6 +93,25 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ if(dataFormat == CNN2DFormat.NHWC)
+ layerInput = layerInput.permute(0, 3, 1, 2);
+ //TODO support more data types
+
+ SDVariable output = sameDiff.cnn.localResponseNormalization(layerInput, LocalResponseNormalizationConfig.builder()
+ .alpha(alpha)
+ .beta(beta) //TODO n and k map to bias and depth? guessing based on the paper but data types don't line up
+ .bias(k)
+ .depth((int) n)
+ .build());
+
+ if(dataFormat == CNN2DFormat.NHWC)
+ output = output.permute(0, 2, 3, 1);
+
+ return output;
+ }
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
@@ -274,7 +296,7 @@ public Builder helperAllowFallback(boolean allowFallback) {
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.
* Default: NCHW
- * @param format Format for activations (in and out)
+ * @param dataFormat Format for activations (in and out)
*/
public Builder dataFormat(CNN2DFormat dataFormat){
this.dataFormat = dataFormat;
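
The LRN conversion normalizes in channels-first layout only, so NHWC activations are permuted in and back out. The same round trip standalone (alpha/beta/bias/depth values arbitrary):

    SameDiff sd = SameDiff.create();
    SDVariable nhwc = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 8, 8, 3)); // [n, h, w, c]
    SDVariable nchw = nhwc.permute(0, 3, 1, 2);                            // to channels-first
    SDVariable lrn = sd.cnn.localResponseNormalization(nchw,
            LocalResponseNormalizationConfig.builder()
                    .alpha(1e-4).beta(0.75).bias(2.0).depth(5)
                    .build());
    SDVariable out = lrn.permute(0, 2, 3, 1);                              // back to NHWC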
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java
index a5028f08a4f0..c829abb6d65b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java
@@ -169,7 +169,8 @@ public void initializeParameters(Map params) {
}
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map paramTable) {
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); // (outH, featureDim, nOut)
int outH = outputSize;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java
index b65d2fe77c72..e86b0ad3f74b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java
@@ -173,7 +173,8 @@ public void initializeParameters(Map params) {
}
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map paramTable) {
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java
index e26dbd83fea1..269a17d8caea 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -28,6 +29,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -49,7 +52,7 @@
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-public class LossLayer extends FeedForwardLayer {
+public class LossLayer extends FeedForwardLayer implements LayerWithLoss {
protected ILossFunction lossFn;
@@ -90,6 +93,18 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ return doActivation(layerInput);
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ return lossFn.defineLoss(sameDiff, input, labels, average);
+ }
+
public static class Builder extends BaseOutputLayer.Builder {
public Builder() {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java
index 75d86460598d..3a90105a6e8e 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java
@@ -19,12 +19,15 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -66,6 +69,28 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners,
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ SDVariable weight = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
+ // may be null
+ SDVariable bias = paramTable.get(DefaultParamInitializer.BIAS_KEY);
+
+ SDVariable temp = layerInput.mmul(weight);
+
+ if(hasBias())
+ temp = temp.add(bias);
+
+ return doActivation(temp);
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ return lossFn.defineLoss(sameDiff, input, labels, average);
+ }
+
@Override
public ParamInitializer initializer() {
return DefaultParamInitializer.getInstance();
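
The OutputLayer conversion is the plain affine transform plus activation. An equivalent standalone graph, assuming a softmax output layer (names and sizes illustrative):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 32, 100)); // [batch, nIn]
    SDVariable w  = sd.var("W",  Nd4j.rand(DataType.FLOAT, 100, 10)); // [nIn, nOut]
    SDVariable b  = sd.var("b",  Nd4j.rand(DataType.FLOAT, 10));
    SDVariable z   = in.mmul(w).add(b);   // the mmul + bias from defineLayer above
    SDVariable out = sd.nn.softmax(z);    // doActivation(...) for a softmax layer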
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java
index b9b8bca4de49..c432b5a0f081 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java
@@ -27,6 +27,9 @@
import org.deeplearning4j.nn.params.PReLUParamInitializer;
import org.deeplearning4j.nn.weights.WeightInitConstant;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -74,6 +77,13 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners,
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ SDVariable alpha = paramTable.get(PReLUParamInitializer.WEIGHT_KEY);
+ return doActivation(sameDiff.nn.prelu(layerInput, alpha, ArrayUtil.toInts(sharedAxes)));
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java
index ff4e1cf76b55..3d2cc9e7d704 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java
@@ -101,7 +101,7 @@ public PrimaryCapsules(Builder builder){
}
@Override
- public SDVariable defineLayer(SameDiff SD, SDVariable input, Map paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff SD, SDVariable input, SDVariable mask, Map paramTable) {
Conv2DConfig conf = Conv2DConfig.builder()
.kH(kernelSize[0]).kW(kernelSize[1])
.sH(stride[0]).sW(stride[1])
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java
index 10659f326154..4e0d4f228070 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java
@@ -181,7 +181,8 @@ public void validateInput(INDArray input) {
}
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map paramTable) {
final val W = paramTable.get(WEIGHT_KEY);
final val R = paramTable.get(RECURRENT_WEIGHT_KEY);
final val b = paramTable.get(BIAS_KEY);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java
index f1dcd73a615b..05248153b5dc 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -30,8 +31,11 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@@ -53,7 +57,7 @@
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-public class RnnLossLayer extends FeedForwardLayer {
+public class RnnLossLayer extends FeedForwardLayer implements LayerWithLoss {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
protected ILossFunction lossFn;
@@ -82,6 +86,43 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ SDVariable batch = sameDiff.sizeAt(layerInput, 0);
+ SDVariable channels;
+ SDVariable size;
+
+ if(rnnDataFormat == RNNFormat.NCW){
+ channels = sameDiff.sizeAt(layerInput, 1);
+ size = sameDiff.sizeAt(layerInput, 2);
+ layerInput = layerInput.permute(0, 2, 1);
+ } else if(rnnDataFormat == RNNFormat.NWC){
+ size = sameDiff.sizeAt(layerInput, 1);
+ channels = sameDiff.sizeAt(layerInput, 2);
+ } else
+ throw new UnsupportedOperationException("Unknown RNN data format " + rnnDataFormat);
+
+ SDVariable distributedInput = layerInput.reshape(sameDiff.concat(0, sameDiff.constant(
+ Nd4j.scalar(batch.dataType(), -1)), channels));
+
+ SDVariable distributedOutput = doActivation(distributedInput);
+
+ SDVariable output = distributedOutput.reshape(sameDiff.concat(0, batch, size, channels));
+
+ if(rnnDataFormat == RNNFormat.NCW)
+ return output.permute(0, 2, 1);
+ else
+ return output;
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ return lossFn.defineLoss(sameDiff, input, labels, average);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
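
Both RNN loss/output conversions rely on the same time-distribution trick: fold the time axis into the batch axis with a dynamic reshape, apply the per-step op, then unfold. Just the reshape round trip, standalone, on NWC input (sizes illustrative):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 4, 10, 3));   // NWC: [batch, time, channels]
    SDVariable batch = sd.sizeAt(in, 0);
    SDVariable time = sd.sizeAt(in, 1);
    SDVariable channels = sd.sizeAt(in, 2);
    SDVariable neg1 = sd.constant(Nd4j.scalar(batch.dataType(), -1));
    SDVariable flat = in.reshape(sd.concat(0, neg1, channels));          // [batch*time, channels]
    // ... apply the per-step activation to 'flat' here ...
    SDVariable back = flat.reshape(sd.concat(0, batch, time, channels)); // [batch, time, channels]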
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java
index cfd337514a43..05d0627c82d8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
@@ -28,9 +29,12 @@
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
@@ -80,6 +84,48 @@ public ParamInitializer initializer() {
return DefaultParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ SDVariable b = paramTable.get(DefaultParamInitializer.BIAS_KEY);
+ SDVariable W = paramTable.get(DefaultParamInitializer.WEIGHT_KEY);
+
+ SDVariable batch = sameDiff.sizeAt(layerInput, 0);
+ SDVariable sequenceLength;
+
+ SDVariable neg1 = sameDiff.constant(Nd4j.scalar(batch.dataType(), -1));
+
+ if(rnnDataFormat == RNNFormat.NCW) {
+ sequenceLength = sameDiff.sizeAt(layerInput, 2);
+ layerInput = layerInput.permute(0, 2, 1);
+ } else if(rnnDataFormat == RNNFormat.NWC)
+ sequenceLength = sameDiff.sizeAt(layerInput, 1);
+ else
+ throw new UnsupportedOperationException("Unknown RNN data format " + rnnDataFormat);
+
+ SDVariable distributedShape = sameDiff.concat(0, batch.mul(sequenceLength), neg1);
+ SDVariable distributedInput = layerInput.reshape(distributedShape);
+
+ SDVariable distributedOutput = distributedInput.mmul(W);
+ if(hasBias)
+ distributedOutput = distributedOutput.add(b);
+
+ distributedOutput = doActivation(distributedOutput);
+
+ SDVariable temp = distributedOutput.reshape(sameDiff.concat(0, batch, sequenceLength, neg1));
+
+ if(rnnDataFormat == RNNFormat.NCW)
+ return temp.permute(0, 2, 1);
+ else
+ return temp;
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
+ return lossFn.defineLoss(sameDiff, input, labels, average);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java
index 79fa765a4984..ca5737b04eaf 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java
@@ -136,7 +136,8 @@ public void initializeParameters(Map params) {
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map paramTable) {
if(projectInput){
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java
index f9ae11b4936e..eec5431bc34a 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java
@@ -21,6 +21,7 @@
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer;
@@ -28,10 +29,14 @@
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.enums.WeightsFormat;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.*;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
/**
* 2D Separable convolution layer configuration.
@@ -149,6 +154,28 @@ public ParamInitializer initializer() {
return SeparableConvolutionParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ SDVariable depthWeight = paramTable.get(SeparableConvolutionParamInitializer.DEPTH_WISE_WEIGHT_KEY);
+ SDVariable pointWeight = paramTable.get(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY);
+ SDVariable bias = paramTable.get(SeparableConvolutionParamInitializer.BIAS_KEY);
+
+ SDVariable value = sameDiff.cnn.separableConv2d(layerInput, depthWeight, pointWeight, bias,
+ Conv2DConfig.builder()
+ .dataFormat(this.cnn2dDataFormat.name())
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .kH(kernelSize[0]).kW(kernelSize[1])
+ .sH(stride[0]).sW(stride[1])
+ .pH(padding[0]).pW(padding[1])
+ .dH(dilation[0]).dW(dilation[1])
+ .weightsFormat(WeightsFormat.OIYX)
+ .build()
+ );
+
+ return doActivation(value);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
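
Separable convolution factors into a per-channel (depthwise) convolution followed by a 1x1 pointwise mix. A standalone sketch of the call above; the weight shapes assume the OIYX weights format selected here (depthwise [depthMultiplier, nIn, kH, kW], pointwise [nOut, nIn * depthMultiplier, 1, 1]) and are illustrative, not taken from this patch:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 1, 3, 32, 32)); // NCHW
    SDVariable dW = sd.var("dW", Nd4j.rand(DataType.FLOAT, 2, 3, 3, 3));   // depth multiplier 2, 3x3 kernels
    SDVariable pW = sd.var("pW", Nd4j.rand(DataType.FLOAT, 16, 6, 1, 1));  // mixes 3*2 channels down to 16
    SDVariable b  = sd.var("b",  Nd4j.rand(DataType.FLOAT, 16));
    SDVariable out = sd.cnn.separableConv2d(in, dW, pW, b, Conv2DConfig.builder()
            .dataFormat("NCHW").isSameMode(true)
            .kH(3).kW(3).sH(1).sW(1).pH(0).pW(0).dH(1).dW(1)
            .weightsFormat(WeightsFormat.OIYX)
            .build());                                                     // [1, 16, 32, 32]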
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java
index 042f09121a6d..98d676f4a838 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java
@@ -27,6 +27,8 @@
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -124,6 +126,13 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ //TODO SameDiff spaceToBatch has issues, see https://github.com/eclipse/deeplearning4j/issues/9019
+ throw new UnsupportedOperationException("Can't convert SpaceToBatchLayer to SameDiff");
+// return sameDiff.cnn.spaceToBatch(layerInput, blocks, padding[0], padding[1]);
+ }
@Override
public void setNIn(InputType inputType, boolean override) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java
index 53d9007be47b..85f7b672d279 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java
@@ -26,6 +26,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -125,6 +127,19 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ org.nd4j.enums.DataFormat format;
+ if(dataFormat == CNN2DFormat.NCHW)
+ format = org.nd4j.enums.DataFormat.NCHW;
+ else if(dataFormat == CNN2DFormat.NHWC)
+ format = org.nd4j.enums.DataFormat.NHWC;
+ else
+ throw new UnsupportedOperationException("Unknown CNN data format " + dataFormat);
+
+ return sameDiff.cnn.spaceToDepth(layerInput, blockSize, format);
+ }
@Override
public void setNIn(InputType inputType, boolean override) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java
index 5d2a55994e80..3b082c6ac561 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java
@@ -19,6 +19,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -28,6 +29,8 @@
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -75,6 +78,16 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ layerInput = sameDiff.expandDims(layerInput, -1);
+
+ SDVariable out = super.defineLayer(sameDiff, layerInput, mask, paramTable);
+ return sameDiff.squeeze(out, -1);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java
index 67a2e804c1aa..579bda2b846e 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java
@@ -22,6 +22,7 @@
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
@@ -29,9 +30,12 @@
import org.deeplearning4j.util.Convolution3DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.learning.regularization.Regularization;
@@ -132,6 +136,27 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ Pooling3DConfig poolingConfig = Pooling3DConfig.builder()
+ .kD(kernelSize[0]).kH(kernelSize[1]).kW(kernelSize[2])
+ .sD(stride[0]).sH(stride[1]).sW(stride[2])
+ .pD(padding[0]).pH(padding[1]).pW(padding[2])
+ .dD(dilation[0]).dH(dilation[1]).dW(dilation[2])
+ .isNCDHW(dataFormat == DataFormat.NCDHW)
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .build();
+
+ if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.MAX){
+ return sameDiff.cnn.maxPooling3d(layerInput, poolingConfig);
+ } else if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.AVG){
+ return sameDiff.cnn.avgPooling3d(layerInput, poolingConfig);
+ } else {
+ throw new UnsupportedOperationException("Can't convert " + poolingType + " pooling layer to SameDiff, only MAX and AVG supported");
+ }
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java
index a434d05bccb6..1f427af4fbb6 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java
@@ -29,12 +29,15 @@
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
/**
* Subsampling layer also referred to as pooling in convolution neural nets
@@ -147,6 +150,28 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ Pooling2DConfig poolingConfig = Pooling2DConfig.builder()
+ .kH(kernelSize[0]).kW(kernelSize[1])
+ .sH(stride[0]).sW(stride[1])
+ .pH(padding[0]).pW(padding[1])
+ .dH(dilation[0]).dW(dilation[1])
+ .isNHWC(cnn2dDataFormat == CNN2DFormat.NHWC)
+ .isSameMode(convolutionMode == ConvolutionMode.Same)
+ .build();
+
+ if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.MAX){
+ return sameDiff.cnn.maxPooling2d(layerInput, poolingConfig);
+ } else if(poolingType == org.deeplearning4j.nn.conf.layers.PoolingType.AVG){
+ return sameDiff.cnn.avgPooling2d(layerInput, poolingConfig);
+ } else {
+ throw new UnsupportedOperationException("Can't convert " + poolingType + " pooling layer to SameDiff, only MAX and AVG supported");
+ }
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
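
Standalone version of the pooling conversion, for a 2x2, stride-2 window in NCHW; the same Pooling2DConfig drives both pooling types:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 1, 3, 8, 8)); // NCHW
    Pooling2DConfig cfg = Pooling2DConfig.builder()
            .kH(2).kW(2).sH(2).sW(2).pH(0).pW(0).dH(1).dW(1)
            .isNHWC(false).isSameMode(false)
            .build();
    SDVariable max = sd.cnn.maxPooling2d(in, cfg); // [1, 3, 4, 4]
    SDVariable avg = sd.cnn.avgPooling2d(in, cfg); // [1, 3, 4, 4]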
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java
index 4b39fa34d22a..a258486c0c18 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java
@@ -20,6 +20,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.ToString;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -28,6 +29,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -84,6 +87,12 @@ public Upsampling1D clone() {
return clone;
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ return sameDiff.squeeze(sameDiff.cnn.upsampling2d(sameDiff.expandDims(layerInput, -1), size[0], 1, true), -1);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
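
Upsampling1D reuses the 2D op by treating the sequence as a width-1 image: expand a trailing unit axis, upsample along the original width only, then squeeze the axis back out. The same idiom standalone:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 1, 4, 10)); // [batch, channels, width]
    SDVariable asImage = sd.expandDims(in, -1);                        // [batch, channels, width, 1]
    SDVariable up = sd.cnn.upsampling2d(asImage, 2, 1, true);          // scale "height" (the width) by 2, NCHW
    SDVariable out = sd.squeeze(up, -1);                               // [batch, channels, 2 * width]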
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java
index 0357c3e7bab9..5361e932f3cc 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java
@@ -26,6 +26,8 @@
import org.deeplearning4j.nn.conf.serde.legacy.LegacyIntArrayDeserializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
@@ -89,6 +91,12 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ return sameDiff.cnn.upsampling2d(layerInput, size[0], size[1], format == CNN2DFormat.NCHW);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java
index 695212d89d3a..4808d33a9736 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java
@@ -20,10 +20,13 @@
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -91,6 +94,12 @@ public InputType getOutputType(int layerIndex, InputType inputType) {
return InputType.convolutional3D(size[0] * inDepth, size[1] * inHeight, size[2] * inWidth, inChannels);
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ return sameDiff.cnn.upsampling3d(layerInput, dataFormat == DataFormat.NCDHW, size[0], size[1], size[2]);
+ }
+
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java
index a3345fde9761..3678d605688a 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java
@@ -27,12 +27,15 @@
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.factory.Nd4j;
/**
* Zero padding 1D layer for convolutional neural networks. Allows padding to be done separately for top and bottom.
@@ -82,6 +85,23 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ int padLeft = padding[0];
+ int padRight = padding[1];
+
+ //TODO support data formats
+ int[][] fullPadding = new int[][]{
+ {0, 0},
+ {0, 0},
+ {padLeft, padRight}
+ };
+
+ return sameDiff.nn.pad(layerInput, sameDiff.constant(Nd4j.createFromArray(fullPadding)), 0);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java
index 8dfd594a6ace..a9578eaefbd3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java
@@ -26,12 +26,15 @@
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.factory.Nd4j;
/**
* Zero padding 3D layer for convolutional neural networks. Allows padding to be done separately for "left" and "right"
@@ -70,6 +73,30 @@ public ParamInitializer initializer() {
return EmptyParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ //TODO support data formats
+ int padLeftD = padding[0];
+ int padRightD = padding[1];
+ int padLeftH = padding[2];
+ int padRightH = padding[3];
+ int padLeftW = padding[4];
+ int padRightW = padding[5];
+
+ int[][] fullPadding;
+ fullPadding = new int[][]{
+ {0, 0},
+ {0, 0},
+ {padLeftD, padRightD},
+ {padLeftH, padRightH},
+ {padLeftW, padRightW}
+ };
+
+ return sameDiff.nn.pad(layerInput, sameDiff.constant(Nd4j.createFromArray(fullPadding)), 0);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java
index ef92d2f8588a..d15a41b6c500 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java
@@ -26,6 +26,8 @@
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -33,6 +35,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.factory.Nd4j;
/**
* Zero padding layer for convolutional neural networks (2D CNNs). Allows padding to be done separately for
@@ -82,6 +85,37 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ int padTop = padding[0];
+ int padBottom = padding[1];
+ int padLeft = padding[2];
+ int padRight = padding[3];
+
+ int[][] fullPadding;
+ if(dataFormat == CNN2DFormat.NCHW){
+ fullPadding = new int[][]{
+ {0, 0},
+ {0, 0},
+ {padTop, padBottom},
+ {padLeft, padRight}
+ };
+ } else if(dataFormat == CNN2DFormat.NHWC) {
+ fullPadding = new int[][]{
+ {0, 0},
+ {padTop, padBottom},
+ {padLeft, padRight},
+ {0, 0}
+ };
+ } else {
+ throw new UnsupportedOperationException("Unknown CNN data format " + dataFormat);
+ }
+
+ return sameDiff.nn.pad(layerInput, sameDiff.constant(Nd4j.createFromArray(fullPadding)), 0);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
int[] hwd = ConvolutionUtils.getHWDFromInputType(inputType);
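
All three zero-padding conversions reduce to nn.pad with one {before, after} pair per dimension, and zeros on the axes that keep their size. Standalone NCHW sketch:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 1, 3, 5, 5));      // NCHW
    int[][] pad = {{0, 0}, {0, 0}, {1, 1}, {2, 2}};                           // pad H by 1/1, W by 2/2
    SDVariable out = sd.nn.pad(in, sd.constant(Nd4j.createFromArray(pad)), 0); // [1, 3, 7, 9]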
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java
index b10b0e716b73..d0f04a95100e 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java
@@ -27,6 +27,9 @@
import org.deeplearning4j.nn.layers.convolution.Cropping1DLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -86,6 +89,19 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ private static Integer end(int idx){
+ if(idx == 0)
+ return null;
+ else
+ return -idx;
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ return layerInput.get(SDIndex.all(), SDIndex.all(), SDIndex.interval(cropping[0], end(cropping[1])));
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
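
The end() helper maps a right-crop amount to a negative interval end, with the special case that a crop of 0 becomes null (an open-ended interval), since interval(begin, 0) would select nothing. Worked example on a length-10 axis:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 1, 4, 10)); // [batch, channels, width]
    // cropping = {2, 3}: end(3) == -3, so interval(2, -3) keeps indices 2..6 (length 10 - 2 - 3 = 5)
    SDVariable cropped = in.get(SDIndex.all(), SDIndex.all(), SDIndex.interval(2, -3)); // [1, 4, 5]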
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java
index 3a13e2fc052d..1a98918142df 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java
@@ -29,6 +29,9 @@
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -103,6 +106,25 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ private static Integer end(int idx){
+ if(idx == 0)
+ return null;
+ else
+ return -idx;
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ if(dataFormat == CNN2DFormat.NCHW) {
+ return layerInput.get(SDIndex.all(), SDIndex.all(), SDIndex.interval(cropping[0], end(cropping[1])), SDIndex.interval(cropping[2], end(cropping[3])));
+ } else if(dataFormat == CNN2DFormat.NHWC){
+ return layerInput.get(SDIndex.all(), SDIndex.interval(cropping[0], end(cropping[1])), SDIndex.interval(cropping[2], end(cropping[3])), SDIndex.all());
+ } else {
+ throw new UnsupportedOperationException("Unknown CNN data format " + dataFormat);
+ }
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
int[] hwd = ConvolutionUtils.getHWDFromInputType(inputType);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java
index 74710a4699e1..d1f9c8630529 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java
@@ -27,6 +27,9 @@
import org.deeplearning4j.nn.layers.convolution.Cropping3DLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -95,6 +98,23 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ private static Integer end(int idx){
+ if(idx == 0)
+ return null;
+ else
+ return -idx;
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ //TODO support different dataTypes
+ return layerInput.get(SDIndex.all(), SDIndex.all(),
+ SDIndex.interval(cropping[0], end(cropping[1])),
+ SDIndex.interval(cropping[2], end(cropping[3])),
+ SDIndex.interval(cropping[4], end(cropping[5])));
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java
index ef88dc8b7ac2..1faf30e8269d 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java
@@ -26,6 +26,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.ElementWiseParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -83,6 +85,17 @@ public ParamInitializer initializer() {
return ElementWiseParamInitializer.getInstance();
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ SDVariable weight = paramTable.get(ElementWiseParamInitializer.WEIGHT_KEY);
+ SDVariable bias = paramTable.get(ElementWiseParamInitializer.BIAS_KEY);
+
+ SDVariable out = layerInput.mul(weight).add(bias);
+
+ return doActivation(out);
+ }
+
/**
* This is a report of the estimated memory consumption for the given layer
*
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java
index f72da09e592a..fcebe4c65fc3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java
@@ -16,8 +16,10 @@
package org.deeplearning4j.nn.conf.layers.misc;
+import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
+import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
@@ -29,12 +31,14 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.FrozenLayerParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.NameScope;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.shade.jackson.annotation.JsonProperty;
-import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import java.util.Collection;
import java.util.List;
@@ -102,6 +106,30 @@ public ParamInitializer initializer() {
return FrozenLayerParamInitializer.getInstance();
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map params) {
+ layer.transformParamsForSameDiff(params);
+ }
+
+ /**
+ * Freezes all params passed to it, by converting them to constants.
+ * @param sameDiff SameDiff instance
+ * @param layerInput Input to the layer
+ * @param mask Optional, may be null. Mask to apply if supported
+ * @param paramTable Parameter table - keys and shapes as defined in the layer implementation class.
+ */
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ for(SDVariable variable : paramTable.values()){
+ variable.convertToConstant();
+ }
+ NameScope underlyingScope = sameDiff.withNameScope("underlying");
+ SDVariable output = layer.defineLayer(sameDiff, layerInput, mask, paramTable);
+ underlyingScope.close();
+ return output;
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
return layer.getOutputType(layerIndex, inputType);
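
In SameDiff terms, freezing is demoting trainable variables to constants: constants receive no gradients and hold no updater state. The defineLayer above does that and then builds the wrapped layer inside an "underlying" name scope. Minimal sketch of the two mechanisms:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 5, 10));
    SDVariable w = sd.var("W", Nd4j.rand(DataType.FLOAT, 10, 10));
    w.convertToConstant();                  // frozen: excluded from training
    NameScope scope = sd.withNameScope("underlying");
    SDVariable out = in.mmul(w);            // created as "underlying/..." in the graph
    scope.close();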
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java
index 468c310329b7..9bd6e522cf0f 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java
@@ -16,7 +16,9 @@
package org.deeplearning4j.nn.conf.layers.misc;
+import java.util.Map;
import lombok.Data;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -24,6 +26,8 @@
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.params.FrozenLayerWithBackpropParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
@@ -83,6 +87,22 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return new org.deeplearning4j.nn.layers.FrozenLayerWithBackprop(underlying);
}
+ /**
+ * Freezes all params passed to it, by converting them to constants.
+ * @param sameDiff SameDiff instance
+ * @param layerInput Input to the layer
+ * @param mask Optional, may be null. Mask to apply if supported
+ * @param paramTable Parameter table - keys and shapes as defined in the layer implementation class.
+ */
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ for(SDVariable variable : paramTable.values()){
+ variable.convertToConstant();
+ }
+ return defineUnderlying(sameDiff, layerInput, mask, paramTable);
+ }
+
@Override
public ParamInitializer initializer() {
return FrozenLayerWithBackpropParamInitializer.getInstance();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java
index e7f252c4ba50..1a6b2ed4c5bb 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java
@@ -26,6 +26,8 @@
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -79,6 +81,24 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
return ret;
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+ layerInput = sameDiff.expandDims(layerInput, -1); // [batch, size, 1]
+ SDVariable out = sameDiff.tile(layerInput, 1, 1, n); // [batch, size, n]
+
+ if(dataFormat == RNNFormat.NWC) {
+ out = out.permute(0, 2, 1); // NCW output is already [batch, size, n]; only NWC needs a permute
+ } else if(dataFormat != RNNFormat.NCW) {
+ throw new UnsupportedOperationException("Unknown RNN data format " + dataFormat);
+ }
+
+ return doActivation(out);
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.FF) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java
index 792e5633b36c..24a79ca759f7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf.layers.recurrent;
+import java.util.HashMap;
import lombok.*;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
@@ -32,6 +33,8 @@
import org.deeplearning4j.nn.params.BidirectionalParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.TimeSeriesUtils;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
@@ -43,7 +46,6 @@
import java.util.Map;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
-import static org.nd4j.linalg.indexing.NDArrayIndex.point;
/**
* Bidirectional is a "wrapper" layer: it wraps any uni-directional RNN layer to make it bidirectional.
Note that
@@ -80,6 +82,7 @@ public enum Mode {
private transient BidirectionalParamInitializer initializer;
private Bidirectional(Bidirectional.Builder builder) {
+ //TODO builder params aren't used?
super(builder);
}
@@ -110,6 +113,92 @@ public Bidirectional(@NonNull Mode mode, @NonNull Layer layer) {
this.mode = mode;
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map params) {
+ Map fwdParams = new HashMap<>();
+ Map bwdParams = new HashMap<>();
+
+ for(String key : params.keySet()){
+ if(key.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){
+ fwdParams.put(key.replaceFirst(BidirectionalParamInitializer.FORWARD_PREFIX, ""), params.get(key));
+ } else if(key.startsWith(BidirectionalParamInitializer.BACKWARD_PREFIX)){
+ bwdParams.put(key.replaceFirst(BidirectionalParamInitializer.BACKWARD_PREFIX, ""), params.get(key));
+ }
+ }
+
+ fwd.transformParamsForSameDiff(fwdParams);
+ bwd.transformParamsForSameDiff(bwdParams);
+
+ params.clear();
+ for(Map.Entry entry : fwdParams.entrySet()){
+ params.put(BidirectionalParamInitializer.FORWARD_PREFIX + entry.getKey(), entry.getValue());
+ }
+ for(Map.Entry entry : bwdParams.entrySet()){
+ params.put(BidirectionalParamInitializer.BACKWARD_PREFIX + entry.getKey(), entry.getValue());
+ }
+ }
+
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map paramTable) {
+
+ Map fwdParams = new HashMap<>();
+ Map bwdParams = new HashMap<>();
+
+ for(String key : paramTable.keySet()){
+ if(key.startsWith(BidirectionalParamInitializer.FORWARD_PREFIX)){
+ fwdParams.put(key.replaceFirst(BidirectionalParamInitializer.FORWARD_PREFIX, ""), paramTable.get(key));
+ } else if(key.startsWith(BidirectionalParamInitializer.BACKWARD_PREFIX)){
+ bwdParams.put(key.replaceFirst(BidirectionalParamInitializer.BACKWARD_PREFIX, ""), paramTable.get(key));
+ }
+ }
+
+ if(fwd instanceof BaseRecurrentLayer){
+
+ try{
+ return ((BaseRecurrentLayer) fwd).defineBidirectional(sameDiff, layerInput, paramTable, mask, mode);
+ } catch (UnsupportedOperationException e) {
+
+
+ SDVariable fwdOut = ((BaseRecurrentLayer) fwd)
+ .defineLayer(sameDiff, layerInput, fwdParams, mask, false);
+ SDVariable bwdOut = ((BaseRecurrentLayer) fwd)
+ .defineLayer(sameDiff, layerInput, bwdParams, mask, true);
+
+
+ bwdOut = sameDiff.reverse(bwdOut, ((BaseRecurrentLayer) fwd).getRnnDataFormat() == RNNFormat.NCW ? 2 : 1);
+
+ if(mode == Mode.CONCAT) {
+ if(((BaseRecurrentLayer) fwd).getRnnDataFormat() == RNNFormat.NCW)
+ return sameDiff.concat(1, fwdOut, bwdOut);
+ else
+ return sameDiff.concat(2, fwdOut, bwdOut);
+ } else if(mode == Mode.ADD)
+ return fwdOut.add(bwdOut);
+ else if(mode == Mode.AVERAGE)
+ return fwdOut.add(bwdOut).div(2);
+ else if(mode == Mode.MUL)
+ return fwdOut.mul(bwdOut);
+ else
+ throw new UnsupportedOperationException("Unknown bidirectional mode " + mode);
+ }
+ } else if(fwd instanceof LastTimeStep){
+ SDVariable fwdOut = fwd.defineLayer(sameDiff, layerInput, mask, fwdParams);
+ SDVariable bwdOut = bwd.defineLayer(sameDiff, layerInput, mask, bwdParams);
+ if(mode == Mode.CONCAT) {
+ return sameDiff.concat(1, fwdOut, bwdOut);
+ } else if(mode == Mode.ADD)
+ return fwdOut.add(bwdOut);
+ else if(mode == Mode.AVERAGE)
+ return fwdOut.add(bwdOut).div(2);
+ else if(mode == Mode.MUL)
+ return fwdOut.mul(bwdOut);
+ else
+ throw new UnsupportedOperationException("Unknown bidirectional mode " + mode);
+ } else
+ throw new UnsupportedOperationException("Bidirectional toSameDiff doesn't support layer " + fwd.getClass().getSimpleName());
+ }
+
public long getNOut() {
if (this.fwd instanceof LastTimeStep) {
return ((FeedForwardLayer) ((LastTimeStep) this.fwd).getUnderlying()).getNOut();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java
index 52c048472f27..5ae8f6844b11 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java
@@ -16,12 +16,17 @@
package org.deeplearning4j.nn.conf.layers.recurrent;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -60,6 +65,13 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
initializeParams, networkDataType));
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable underlyingOutput = defineUnderlying(sameDiff, layerInput, mask, paramTable);
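+ // Take the last step along the time axis (assumes NCW layout: [batch, size, time])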
+ return underlyingOutput.get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType.getType() != InputType.Type.RNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java
index 7bc91c17eb00..4bce1dd45bd0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java
@@ -19,21 +19,27 @@
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.SimpleRnnParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
/**
* Simple RNN - aka "vanilla" RNN is the simplest type of recurrent neural network layer. It implements {@code out_t =
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java
index 5489ccc78d0e..8887e7ddbff8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java
@@ -1,5 +1,6 @@
package org.deeplearning4j.nn.conf.layers.recurrent;
+import java.util.Map;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
@@ -11,8 +12,12 @@
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Collection;
@@ -53,6 +58,36 @@ public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
initializeParams, networkDataType), rnnDataFormat);
}
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable originalShape = layerInput.shape();
+ SDVariable batch = originalShape.get(SDIndex.point(0));
+ SDVariable sequenceLength;
+
+ SDVariable neg1 = sameDiff.constant(Nd4j.scalar(batch.dataType(), -1));
+
+ if(rnnDataFormat == RNNFormat.NCW) {
+ sequenceLength = originalShape.get(SDIndex.point(2));
+ layerInput = layerInput.permute(0, 2, 1);
+ } else if(rnnDataFormat == RNNFormat.NWC)
+ sequenceLength = originalShape.get(SDIndex.point(1));
+ else
+ throw new UnsupportedOperationException("Unknown RNN data format " + rnnDataFormat);
+
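+ // Flatten time into the batch dimension: [batch, time, features] -> [batch * time, features]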
+ SDVariable distributedShape = sameDiff.concat(0, batch.mul(sequenceLength), neg1);
+ SDVariable distributedInput = layerInput.reshape(distributedShape);
+
+ SDVariable distributedOutput = defineUnderlying(sameDiff, distributedInput, mask, paramTable);
+
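+ // Un-flatten: restore the time dimension, [batch * time, features] -> [batch, time, features]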
+ SDVariable temp = distributedOutput.reshape(sameDiff.concat(0, batch, sequenceLength, neg1));
+
+ if(rnnDataFormat == RNNFormat.NCW)
+ return temp.permute(0, 2, 1);
+ else
+ return temp;
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType.getType() != InputType.Type.RNN) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java
index 7271d59efaad..d4800dc174b8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java
@@ -42,7 +42,8 @@ public abstract class SameDiffLambdaLayer extends SameDiffLayer {
public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput);
@Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
+ public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable) {
return defineLayer(sameDiff, layerInput);
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java
index f290a09f36f1..5452471363f0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java
@@ -74,12 +74,12 @@ protected SameDiffLayer() {
*
* @param sameDiff SameDiff instance
* @param layerInput Input to the layer
- * @param paramTable Parameter table - keys as defined by {@link #defineParameters(SDLayerParams)}
* @param mask Optional, maybe null. Mask to apply if supported
+ * @param paramTable Parameter table - keys as defined by {@link #defineParameters(SDLayerParams)}
* @return The final layer variable corresponding to the activations/output from the forward pass
*/
- public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput,
- Map<String, SDVariable> paramTable, SDVariable mask);
+ public abstract SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable);
/**
* @see Layer#feedForwardMaskArray(INDArray, MaskState, int)
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java
index d6c4892f37fd..6e779485c081 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.conf.layers.samediff;
+import lombok.NonNull;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.autodiff.samediff.SDVariable;
@@ -63,9 +64,15 @@ protected SameDiffOutputLayer() {
* @param paramTable Parameter table - keys as defined by {@link #defineParameters(SDLayerParams)}
* @return The final layer variable corresponding to the score/loss during forward pass. This must be a single scalar value.
*/
- public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, SDVariable labels,
+ public abstract SDVariable defineLayerAndLoss(SameDiff sameDiff, SDVariable layerInput, SDVariable labels,
Map<String, SDVariable> paramTable);
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ throw new IllegalStateException("SameDiffOutputLayers should be defined using defineLayerAndLoss, which takes the labels");
+ }
+
/**
* Output layers should terminate in a single scalar value (i.e., a score) - however, sometimes the output activations
* (such as softmax probabilities) need to be returned. When this is the case, we need to know the name of the
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java
index fd9c64ba6643..e0d89a98eb00 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java
@@ -16,9 +16,11 @@
package org.deeplearning4j.nn.conf.layers.util;
+import java.util.Map;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -27,6 +29,8 @@
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java
index c11f618dafee..ee987f2e6ef3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java
@@ -16,7 +16,9 @@
package org.deeplearning4j.nn.conf.layers.wrapper;
+import java.util.Map;
import lombok.Data;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
@@ -24,6 +26,10 @@
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.params.WrapperLayerParamInitializer;
+import org.nd4j.autodiff.samediff.NameScope;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.List;
@@ -54,6 +60,19 @@ public ParamInitializer initializer() {
return WrapperLayerParamInitializer.getInstance();
}
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params) {
+ underlying.transformParamsForSameDiff(params);
+ }
+
+ protected SDVariable defineUnderlying(SameDiff sameDiff, SDVariable layerInput, SDVariable mask,
+ Map<String, SDVariable> paramTable){
+ NameScope underlyingScope = sameDiff.withNameScope("underlying");
+ SDVariable output = underlying.defineLayer(sameDiff, layerInput, mask, paramTable);
+ underlyingScope.close();
+ return output;
+ }
+
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
return underlying.getOutputType(layerIndex, inputType);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java
index 539289ecafbd..2fd04496064c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java
@@ -24,6 +24,8 @@
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
@@ -124,6 +126,25 @@ public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
+ @Override
+ public SDVariable defineLayer(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ SDVariable w = paramTable.get(OCNNParamInitializer.W_KEY);
+ SDVariable v = paramTable.get(OCNNParamInitializer.V_KEY);
+
+ SDVariable wFlat = w.reshape(sameDiff.concat(0, sameDiff.sizeAt(w, 0), sameDiff.constant(-1)));
+
+ SDVariable first = layerInput.mul(v);
+ SDVariable act2d = doActivation(first);
+ return act2d.mul(wFlat); //TODO DL4J implementation sets labels to the output as well, will this work here? probably not
+ }
+
+ @Override
+ public SDVariable defineLoss(@NonNull SameDiff sameDiff, @NonNull SDVariable input, SDVariable labels,
+ boolean average) {
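+ // As noted in the TODO above, the DL4J OCNN implementation uses the layer's own activations
+ // as the "labels" for the loss, which is why input is passed twice here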
+ return lossFn.defineLoss(sameDiff, input, input, average);
+ }
+
@Override
public long getNOut() {
//we don't change number of outputs here
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java
index 83f1097b7820..5203cca260bd 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/BaseInputPreProcessor.java
@@ -16,8 +16,11 @@
package org.deeplearning4j.nn.conf.preprocessor;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@@ -43,4 +46,8 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt
//Default: pass-through, unmodified
return new Pair<>(maskArray, currentMaskState);
}
+
+ public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input){
+ throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java
index f6ba7af7b344..0041fadfbc07 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/Cnn3DToFeedForwardPreProcessor.java
@@ -49,7 +49,7 @@
* @see FeedForwardToCnn3DPreProcessor for opposite case (i.e., DenseLayer -> CNN3D)
*/
@Data
-public class Cnn3DToFeedForwardPreProcessor implements InputPreProcessor {
+public class Cnn3DToFeedForwardPreProcessor extends BaseInputPreProcessor {
protected long inputDepth;
protected long inputHeight;
protected long inputWidth;
@@ -141,16 +141,6 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr
}
- @Override
- public Cnn3DToFeedForwardPreProcessor clone() {
- try {
- Cnn3DToFeedForwardPreProcessor clone = (Cnn3DToFeedForwardPreProcessor) super.clone();
- return clone;
- } catch (CloneNotSupportedException e) {
- throw new RuntimeException(e);
- }
- }
-
@Override
public InputType getOutputType(InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java
index 1a7e3928b9be..0779d4c1786e 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java
@@ -17,11 +17,14 @@
package org.deeplearning4j.nn.conf.preprocessor;
import lombok.Data;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.common.primitives.Pair;
@@ -51,7 +54,7 @@
* @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc)
*/
@Data
-public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+public class CnnToFeedForwardPreProcessor extends BaseInputPreProcessor {
protected long inputHeight;
protected long inputWidth;
protected long numChannels;
@@ -157,16 +160,6 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace
}
- @Override
- public CnnToFeedForwardPreProcessor clone() {
- try {
- CnnToFeedForwardPreProcessor clone = (CnnToFeedForwardPreProcessor) super.clone();
- return clone;
- } catch (CloneNotSupportedException e) {
- throw new RuntimeException(e);
- }
- }
-
@Override
public InputType getOutputType(InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
@@ -195,4 +188,9 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt
return new Pair<>(maskArray.reshape(maskArray.ordering(), maskArray.size(0), maskArray.size(1)), currentMaskState);
}
+
+ @Override
+ public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
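+ // Flatten [batch, channels, height, width] activations into [batch, channels * height * width]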
+ return input.reshape(-1, numChannels * inputHeight * inputWidth);
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java
index 6f18e70e4e88..de4b5c5a2338 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java
@@ -48,7 +48,7 @@
*/
@Data
@EqualsAndHashCode(exclude = {"product"})
-public class CnnToRnnPreProcessor implements InputPreProcessor {
+public class CnnToRnnPreProcessor extends BaseInputPreProcessor {
private long inputHeight;
private long inputWidth;
private long numChannels;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java
index 55f0e6b12e95..bd2fa5bdbbca 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/ComposableInputPreProcessor.java
@@ -18,10 +18,13 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.workspace.ArrayType;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -90,4 +93,11 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt
}
return new Pair<>(maskArray, currentMaskState);
}
+
+ @Override
+ public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
+ for(InputPreProcessor preProcessor : inputPreProcessors)
+ input = preProcessor.definePreProcess(sameDiff, input);
+ return input;
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java
index 305ace53012d..461e8de77336 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java
@@ -48,7 +48,7 @@
*/
@Data
@EqualsAndHashCode(exclude = {"shape"})
-public class FeedForwardToCnn3DPreProcessor implements InputPreProcessor {
+public class FeedForwardToCnn3DPreProcessor extends BaseInputPreProcessor {
private int inputDepth;
private int inputHeight;
private int inputWidth;
@@ -126,14 +126,10 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr
@Override
public FeedForwardToCnn3DPreProcessor clone() {
- try {
- FeedForwardToCnn3DPreProcessor clone = (FeedForwardToCnn3DPreProcessor) super.clone();
- if (clone.shape != null)
- clone.shape = clone.shape.clone();
- return clone;
- } catch (CloneNotSupportedException e) {
- throw new RuntimeException(e);
- }
+ FeedForwardToCnn3DPreProcessor clone = (FeedForwardToCnn3DPreProcessor) super.clone();
+ if (clone.shape != null)
+ clone.shape = clone.shape.clone();
+ return clone;
}
@Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java
index 817d538489e3..c7ef7e079493 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnnPreProcessor.java
@@ -20,6 +20,8 @@
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.common.primitives.Pair;
@@ -48,7 +50,7 @@
*/
@Data
@EqualsAndHashCode(exclude = {"shape"})
-public class FeedForwardToCnnPreProcessor implements InputPreProcessor {
+public class FeedForwardToCnnPreProcessor extends BaseInputPreProcessor {
private long inputHeight;
private long inputWidth;
private long numChannels;
@@ -115,14 +117,10 @@ public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr
@Override
public FeedForwardToCnnPreProcessor clone() {
- try {
- FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
- if (clone.shape != null)
- clone.shape = clone.shape.clone();
- return clone;
- } catch (CloneNotSupportedException e) {
- throw new RuntimeException(e);
- }
+ FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor) super.clone();
+ if (clone.shape != null)
+ clone.shape = clone.shape.clone();
+ return clone;
}
@Override
@@ -167,4 +165,13 @@ public Pair feedForwardMaskArray(INDArray maskArray, MaskSt
return new Pair<>(maskArray, currentMaskState);
}
+ @Override
+ public SDVariable definePreProcess(@NonNull SameDiff sameDiff, @NonNull SDVariable input) {
+ //TODO Assuming shape of input is correct, it would be better to check & throw exception here, but needs offline shape inference
+
+ if(numChannels == -1)
+ throw new IllegalStateException("Can't convert when numChannels isn't explicitly specified");
+ //TODO get batch size. Needs offline shape inference
+ return input.reshape(-1, numChannels, inputHeight, inputWidth);
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java
index e6bca1bed18a..34b5417248f2 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java
@@ -48,7 +48,7 @@
*/
@Data
@NoArgsConstructor
-public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
+public class FeedForwardToRnnPreProcessor extends BaseInputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java
index 57487aae7267..27e077400116 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java
@@ -48,7 +48,7 @@
*/
@Data
@EqualsAndHashCode(exclude = {"product"})
-public class RnnToCnnPreProcessor implements InputPreProcessor {
+public class RnnToCnnPreProcessor extends BaseInputPreProcessor {
private int inputHeight;
private int inputWidth;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java
index d4355e38b651..507e436b80aa 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java
@@ -51,7 +51,7 @@
@Data
@Slf4j
@NoArgsConstructor
-public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
+public class RnnToFeedForwardPreProcessor extends BaseInputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index 2f7bd45eeecf..990b86f87ab1 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -24,7 +24,16 @@
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
+import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.LayerWithLoss;
+import org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer;
+import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater;
+import org.deeplearning4j.util.ToSameDiffUtils;
import org.nd4j.adapters.OutputAdapter;
+import org.nd4j.autodiff.samediff.NameScope;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.*;
@@ -91,6 +100,9 @@
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.learning.config.NoOp;
+import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
@@ -747,6 +759,208 @@ public void init(INDArray parameters, boolean cloneParametersArray) {
initCalled = true;
}
+ /**
+ *
+ * Create this ComputationGraph in a SameDiff instance.
+ *
+ * Input placeholders are created using the network input names; label placeholders are created with the name "labels" inside each output vertex's name scope.
+ * Output and loss variables are set on the SameDiff instance and can be retrieved from it.
+ *
+ * @param sameDiff The SameDiff instance to create the model in
+ * @param inputTypes The types of the inputs.
+ * @param useView Whether to directly use the (view) weights in the SDVariables, or create new ones.
+ * Using the views saves an initialization (of every weight), but may cause issues with multi-GPU setups.
+ * @param skipErrors Whether to ignore updater or regularization configurations that aren't the same on all layers, rather than throwing.
+ * @return The {@link org.nd4j.autodiff.samediff.TrainingConfig} if training was set up (at least one output defines a loss), or null if not.
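+ *
+ * <p>A minimal usage sketch (the input name and type here are illustrative, not fixed by the API):
+ * <pre>{@code
+ * SameDiff sd = SameDiff.create();
+ * Map<String, InputType> inputTypes = new HashMap<>();
+ * inputTypes.put("input", InputType.feedForward(784));
+ * TrainingConfig tc = graph.toSameDiff(sd, inputTypes, true, true);
+ * }</pre>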
+ */
+ public org.nd4j.autodiff.samediff.TrainingConfig toSameDiff(@NonNull SameDiff sameDiff, @NonNull Map<String, InputType> inputTypes, boolean useView, boolean skipErrors) {
+
+ if (!initCalled)
+ init();
+
+ Preconditions.checkArgument(inputTypes.keySet().equals(new HashSet<>(configuration.getNetworkInputs())), "Must specify input types for all inputs. Expected %s, but got %s.",
+ configuration.getNetworkInputs(), inputTypes.keySet());
+
+ InputType[] inputVertTypes = new InputType[inputTypes.size()];
+ int j = 0;
+ for(String inputName : configuration.getNetworkInputs()){
+ inputVertTypes[j] = inputTypes.get(inputName);
+ j++;
+ }
+
+ Map<String, InputType> outputTypes = configuration.getLayerActivationTypes(true, inputVertTypes);
+
+ Map<String, SDVariable> activations = new HashMap<>();
+ for(Map.Entry<String, InputType> input : inputTypes.entrySet()){
+ activations.put(input.getKey(), sameDiff.placeHolder(input.getKey(), configuration.getDataType(), input.getValue().getShape(true)));
+ }
+
+ Map<String, SDVariable> sdOutputLabels = new HashMap<>();
+
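+ // Define vertices in topological order so that every vertex's inputs are already defined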
+ for (int i : topologicalOrder) {
+ GraphVertex vertex = vertices[i];
+ String name = vertex.getVertexName();
+
+ if(vertex instanceof InputVertex)
+ continue;
+
+ NameScope layerScope = sameDiff.withNameScope(name);
+
+ Map<String, SDVariable> paramTable = ToSameDiffUtils.defineParams(sameDiff, vertex, useView);
+
+ SDVariable[] inputs = new SDVariable[vertex.getNumInputArrays()];
+ j = 0;
+ for(String inputVertex : configuration.getVertexInputs().get(name)){
+ inputs[j] = activations.get(inputVertex);
+ j++;
+ }
+
+ SDVariable output;
+ if(vertex.hasLayer() && vertex.getLayer() instanceof SameDiffOutputLayer){
+ String inputName = configuration.getVertexInputs().get(name).get(0);
+ SDVariable labels = null;
+ if(((SameDiffOutputLayer) vertex.getLayer()).needsLabels()){
+ labels = sameDiff
+ .placeHolder("labels", configuration.getDataType(), outputTypes.get(inputName).getShape(true));
+ }
+
+ SDVariable input = activations.get(inputName);
+ output = ((SameDiffOutputLayer) vertex.getLayer()).layerConf().defineLayerAndLoss(sameDiff, input, labels, paramTable);
+ sdOutputLabels.put(name, labels);
+ } else {
+ output = vertex.defineVertex(sameDiff, inputs, null, paramTable);
+ }
+
+ activations.put(name, output);
+
+ layerScope.close();
+ }
+
+ List<String> sdOutputs = new ArrayList<>();
+ for(String vertex : configuration.getNetworkOutputs()){
+ sdOutputs.add(activations.get(vertex).name());
+ }
+
+ sameDiff.setOutputs(sdOutputs);
+
+ List<String> losses = new ArrayList<>();
+ List<String> allLabels = new ArrayList<>();
+
+ for(String output : configuration.getNetworkOutputs()){
+ GraphVertex vertex = verticesMap.get(output);
+ SDVariable loss;
+ SDVariable labels;
+ if(vertex.hasLayer() && vertex.getLayer() instanceof SameDiffOutputLayer) {
+ loss = activations.get(vertex.getVertexName());
+ labels = sdOutputLabels.get(vertex.getVertexName());
+
+ } else if(vertex.hasLayer() && vertex.getLayer() instanceof IOutputLayer && vertex.getLayer().conf().getLayer() instanceof LayerWithLoss){
+ LayerWithLoss lossLayer = (LayerWithLoss) vertex.getLayer().conf().getLayer();
+ SDVariable input = activations.get(output);
+ labels = null;
+
+ NameScope vertexScope = sameDiff.withNameScope(vertex.getVertexName());
+
+ if(((IOutputLayer) vertex.getLayer()).needsLabels()) {
+ labels = sameDiff
+ .placeHolder("labels", configuration.getDataType(), outputTypes.get(output).getShape(true));
+ }
+ NameScope lossScope = sameDiff.withNameScope("loss");
+
+ loss = lossLayer.defineLoss(sameDiff, input, labels, conf().isMiniBatch());
+ lossScope.close();
+
+ loss.rename("loss");
+
+ vertexScope.close();
+
+ } else {
+ continue;
+ }
+
+ losses.add(loss.name());
+ if(labels != null)
+ allLabels.add(labels.name());
+ }
+
+ if(losses.size() > 0){
+
+ IUpdater iUpdater = ToSameDiffUtils.getUpdater(layers, skipErrors);
+ List<Regularization> regularizations = ToSameDiffUtils.getRegularizations(layers, skipErrors);
+
+ String[] lossArr = losses.toArray(new String[0]);
+ sameDiff.setLossVariables(lossArr);
+
+ TrainingConfig.Builder tcBuilder = org.nd4j.autodiff.samediff.TrainingConfig.builder()
+ .minimize(lossArr)
+ .minimize(conf().isMinimize())
+ .dataSetFeatureMapping(configuration.getNetworkInputs().toArray(new String[0]));
+
+ if(regularizations != null)
+ tcBuilder.regularization(regularizations);
+
+ if(iUpdater != null)
+ tcBuilder.updater(iUpdater.clone());
+ else
+ tcBuilder.updater(new NoOp());
+
+ if(allLabels.size() == 0)
+ tcBuilder.markLabelsUnused();
+ else
+ tcBuilder.dataSetLabelMapping(allLabels.toArray(new String[0]));
+
+ org.nd4j.autodiff.samediff.TrainingConfig trainingConfig = tcBuilder.build();
+
+ trainingConfig.setIterationCount(getIterationCount());
+ trainingConfig.setEpochCount(getEpochCount());
+
+ sameDiff.setTrainingConfig(trainingConfig);
+
+ if(iUpdater != null) {
+ Updater updater = getUpdater();
+
+ if(updater instanceof BaseMultiLayerUpdater){
+ ToSameDiffUtils.copyUpdaterState(sameDiff, (BaseMultiLayerUpdater<?>) updater, null);
+ } else {
+ if(skipErrors)
+ log.warn("Unsupported updater type {}, not copying updater state to SameDiff", updater.getClass().getSimpleName());
+ else
+ throw new IllegalStateException("Unsupported updater type " + updater.getClass().getSimpleName() + ", could not updater state to SameDiff");
+ }
+
+
+ }
+
+ return trainingConfig;
+ }
+
+ return null;
+ }
+
+
+ /**
+ * See {@link #toSameDiff(SameDiff, Map, boolean, boolean)}. {@code useView} and {@code skipErrors} are true.
+ */
+ public TrainingConfig toSameDiff(@NonNull SameDiff sameDiff, @NonNull Map<String, InputType> inputTypes){
+ return toSameDiff(sameDiff, inputTypes, true, true);
+ }
+
+ /**
+ * See {@link #toSameDiff(SameDiff, Map, boolean, boolean)}.
+ */
+ public SameDiff toSameDiff(@NonNull Map<String, InputType> inputTypes, boolean useView, boolean skipErrors){
+ SameDiff sameDiff = SameDiff.create();
+ toSameDiff(sameDiff, inputTypes, useView, skipErrors);
+ return sameDiff;
+ }
+
+ /**
+ * See {@link #toSameDiff(SameDiff, Map, boolean, boolean)}. {@code useView} and {@code skipErrors} are true.
+ */
+ public SameDiff toSameDiff(@NonNull Map<String, InputType> inputTypes){
+ return toSameDiff(inputTypes, true, true);
+ }
+
/**
* This method: initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers.
* As a general rule, this shouldn't ever need to be called manually when doing training via fit(DataSet), fit(DataSetIterator)
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java
index 555c64f94d03..c2b13daab697 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java
@@ -18,10 +18,13 @@
import lombok.Data;
import lombok.Getter;
+import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -76,6 +79,18 @@ protected BaseGraphVertex(ComputationGraph graph, String name, int vertexIndex,
this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)];
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
+
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params){
+ if(hasLayer())
+ getLayer().conf().getLayer().transformParamsForSameDiff(params);
+ }
+
@Override
public String getVertexName() {
return vertexName;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java
index de6b18428335..3e8128f3cc8b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java
@@ -16,11 +16,14 @@
package org.deeplearning4j.nn.graph.vertex;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@@ -41,6 +44,17 @@ protected BaseWrapperVertex(GraphVertex underlying){
this.underlying = underlying;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ throw new UnsupportedOperationException("SameDiff conversion has not been implemented for " + this.getClass().getSimpleName());
+ }
+
+ @Override
+ public void transformParamsForSameDiff(@NonNull Map<String, INDArray> params){
+ underlying.transformParamsForSameDiff(params);
+ }
+
@Override
public String getVertexName() {
return underlying.getVertexName();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java
index a2477f9407c2..0a5610950658 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java
@@ -16,10 +16,13 @@
package org.deeplearning4j.nn.graph.vertex;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.gradient.Gradient;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -92,6 +95,33 @@ public interface GraphVertex extends Trainable, Serializable {
/** Get the Layer (if any). Returns null if {@link #hasLayer()} == false */
Layer getLayer();
+ /**
+ * Define the vertex for conversion to {@link SameDiff}.
+ * If this isn't supported, this method should throw an {@link UnsupportedOperationException},
+ * like the default implementation in {@link BaseGraphVertex}.
+ *
+ * @param sameDiff The {@link SameDiff} instance to define in.
+ * @param inputs The inputs to the vertex, in the same order as {@link #getInputVertices()}.
+ * @param mask The mask. May be null.
+ * @param paramTable The parameters for the vertex. Keys will be the same as {@link #paramTable(boolean)}.
+ * @return The output of the vertex.
+ */
+ SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs, SDVariable mask,
+ @NonNull Map<String, SDVariable> paramTable);
+
+ /**
+ * Do any necessary transforms to parameters (weights, biases, etc) before making SDVariables out of them.
+ * Useful for things like changing the dimension order or squeezing.
+ *
+ * Adding or removing parameters is supported.
+ *
+ * Should throw an {@link UnsupportedOperationException} if conversion of this layer configuration isn't
+ * supported and transforming the weights would otherwise cause an error.
+ *
+ * @param params The parameter table, keyed as in {@link #paramTable(boolean)}.
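+ *
+ * <p>For example (a hypothetical illustration), an implementation could transpose a weight matrix in place:
+ * <pre>{@code params.put("W", params.get("W").transpose()); }</pre>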
+ */
+ void transformParamsForSameDiff(@NonNull Map<String, INDArray> params);
+
/** Set the input activations.
* @param inputNumber Must be in range 0 to {@link #getNumInputArrays()}-1
* @param input The input array
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java
index 5018dbe71829..1a18b07eae16 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -76,6 +80,36 @@ public Layer getLayer() {
return null;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ if(inputs.length == 1)
+ return inputs[0];
+
+ if (op == Op.Subtract && inputs.length != 2)
+ throw new IllegalArgumentException("ElementWise subtraction only supports 2 inputs");
+
+ SDVariable acc = inputs[0];
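+ // Accumulate pairwise; for Average, inputs are summed here and divided by the input count below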
+ for(int i = 1 ; i < inputs.length ; i++){
+ SDVariable next = inputs[i];
+ if(op == Op.Add)
+ acc = acc.add(next);
+ else if(op == Op.Subtract)
+ acc = acc.sub(next);
+ else if(op == Op.Product)
+ acc = acc.mul(next);
+ else if(op == Op.Average)
+ acc = acc.add(next);
+ else if(op == Op.Max)
+ acc = sameDiff.math.max(acc, next);
+ }
+
+ if(op == Op.Average)
+ acc = acc.div(inputs.length);
+
+ return acc;
+ }
+
@Override
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
if (!canDoForward())
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java
index 955ba8aba274..07e6648bf6a9 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java
@@ -16,16 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
-import lombok.AllArgsConstructor;
+import java.util.Map;
import lombok.EqualsAndHashCode;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.TrainingConfig;
-import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.misc.DummyConfig;
import org.deeplearning4j.nn.graph.vertex.BaseWrapperVertex;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.learning.config.IUpdater;
-import org.nd4j.linalg.learning.config.NoOp;
+import org.nd4j.autodiff.samediff.NameScope;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
/**
* FrozenVertex is used for the purposes of transfer learning
@@ -48,4 +48,16 @@ public TrainingConfig getConfig(){
}
return config;
}
+
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
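+ // Convert all parameters to constants so the underlying vertex receives no gradient updates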
+ for(SDVariable variable : paramTable.values()){
+ variable.convertToConstant();
+ }
+ NameScope underlyingScope = sameDiff.withNameScope("underlying");
+ SDVariable output = underlying.defineVertex(sameDiff, inputs, mask, paramTable);
+ underlyingScope.close();
+ return output;
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java
index 32e67134514a..de92a3249357 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@@ -37,6 +41,12 @@ public InputVertex(ComputationGraph graph, String name, int vertexIndex, VertexI
super(graph, name, vertexIndex, null, outputVertices, dataType);
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ throw new IllegalStateException("InputVertices should never be manually converted to SameDiff");
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java
index 894c026eda24..37c29ecf41eb 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -59,6 +63,17 @@ public L2NormalizeVertex(ComputationGraph graph, String name, int vertexIndex, V
this.eps = eps;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ if(dimension == null || dimension.length < 1)
+ throw new IllegalStateException("Dimension must be set for toSameDiff conversion.");
+
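+ // Divide by max(||x||_2, eps), clamping the norm at eps to avoid division by zero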
+ SDVariable factor = sameDiff.max(inputs[0].norm2(dimension), sameDiff.constant(Nd4j.scalar(inputs[0].dataType(), eps)));
+ return inputs[0].div(factor);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java
index 20394880caab..7755166fc3cc 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -54,6 +58,16 @@ public L2Vertex(ComputationGraph graph, String name, int vertexIndex, VertexIndi
this.eps = eps;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
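+ // Pairwise L2 distance: sqrt(sum((a - b)^2)), reduced over all dimensions except the batch dimension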
+ SDVariable temp = inputs[0].sub(inputs[1]);
+ temp = temp.mul(temp);
+ temp = temp.reshape(sameDiff.concat(0, sameDiff.sizeAt(temp, 0), sameDiff.constant(Nd4j.scalar(DataType.INT64, -1))))
+ .sum(true, 1);
+ return sameDiff.math.sqrt(temp);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java
index d803808dff42..357dc6259de0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java
@@ -18,6 +18,7 @@
import lombok.Data;
import lombok.EqualsAndHashCode;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
@@ -31,6 +32,9 @@
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
+import org.nd4j.autodiff.samediff.NameScope;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@@ -76,6 +80,28 @@ public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexI
this.inputs = new INDArray[(inputVertices != null ? inputVertices.length : 0)];
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ org.deeplearning4j.nn.conf.layers.Layer layerConf = layer.conf().getLayer();
+
+ InputPreProcessor preProcessor = getLayerPreProcessor();
+
+ SDVariable input = inputs[0];
+
+ if(preProcessor != null){
+ NameScope preProcessorScope = sameDiff.withNameScope("inputPreprocessor");
+ input = preProcessor.definePreProcess(sameDiff, input);
+ preProcessorScope.close();
+ }
+
+ if(layerConf.getIDropout() != null){
+ input = layerConf.getIDropout().defineDropout(sameDiff, input);
+ }
+
+ return layerConf.defineLayer(sameDiff, input, null, paramTable);
+ }
+
@Override
public boolean hasLayer() {
return true;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java
index 702767e8a573..787f2498b951 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
@@ -23,6 +25,8 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -61,6 +65,15 @@ public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexI
this.mergeAxis = mergeAxis;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
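+ // A single input needs no merge; otherwise concatenate along the configured merge axis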
+ if(inputs.length == 1)
+ return inputs[0];
+
+ return sameDiff.concat(mergeAxis, inputs);
+ }
+
@Override
public String toString() {
return "MergeVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\")";
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java
index 962b020817cb..fec76a507640 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java
@@ -16,12 +16,17 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
@@ -49,6 +54,16 @@ public PoolHelperVertex(ComputationGraph graph, String name, int vertexIndex, Ve
super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ if (inputs.length > 1)
+ throw new IllegalStateException("PoolHelper vertex requires a single input.");
+
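+ // Crop the activations from index 1 along both spatial dimensions (Caffe-style pooling compatibility)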
+ return inputs[0].get(SDIndex.all(), SDIndex.all(), SDIndex.interval(1, -1), SDIndex.interval(1, -1));
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java
index d8fe856174e1..c493c84034d7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor;
@@ -23,6 +25,8 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@@ -46,6 +50,12 @@ public PreprocessorVertex(ComputationGraph graph, String name, int vertexIndex,
this.preProcessor = preProcessor;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
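+ // Delegate directly to the preprocessor's own SameDiff definition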
+ return preProcessor.definePreProcess(sameDiff, inputs[0]);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java
index 39fcac462cdb..f69b64ba7ed9 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
@@ -24,6 +26,8 @@
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@@ -54,6 +58,16 @@ public ReshapeVertex(ComputationGraph graph, String name, int vertexIndex, Verte
this.maskShape = maskShape;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ if (inputs.length > 1)
+ throw new IllegalStateException("Reshape vertex requires a single input.");
+
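+ // Reshape to the vertex's configured new shape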
+ return inputs[0].reshape(newShape);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java
index c4fa89239b68..924e41901393 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -50,6 +54,17 @@ public ScaleVertex(ComputationGraph graph, String name, int vertexIndex, VertexI
this.scaleFactor = scaleFactor;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ if (inputs.length > 1)
+ throw new IllegalArgumentException(
+ "ScaleVertex (name " + vertexName + " idx " + vertexIndex + ") only supports 1 input.");
+
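+ // Element-wise multiplication by the fixed scale factor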
+ return inputs[0].mul(scaleFactor);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java
index d9f5c78de6f9..74589f3a65b6 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -59,6 +63,17 @@ public ShiftVertex(ComputationGraph graph, String name, int vertexIndex, VertexI
this.shiftFactor = shiftFactor;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+
+ if (inputs.length > 1)
+ throw new IllegalArgumentException(
+ "ShiftVertex (name " + vertexName + " idx " + vertexIndex + ") only supports 1 input.");
+
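+ // Element-wise addition of the fixed shift factor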
+ return inputs[0].add(shiftFactor);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java
index 3be9d6895581..b330fef451b8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
@@ -23,6 +25,8 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java
index db44492935f8..e8af7a29caf7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java
index 9eb4151c8193..b60e5175676d 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java
@@ -16,14 +16,20 @@
package org.deeplearning4j.nn.graph.vertex.impl;
+import java.util.Map;
+import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.common.primitives.Pair;
@@ -59,6 +65,13 @@ public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, Verte
this.stackSize = stackSize;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
+ //TODO: the subset step depends on the runtime minibatch size - it cannot currently be computed as an int or expressed as an SDVariable, so defer to the (unsupported) base implementation
+ return super.defineVertex(sameDiff, inputs, mask, paramTable);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java
index 8b3f2fba0b12..126cd8b64871 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.nn.graph.vertex.impl.rnn;
+import java.util.Map;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
@@ -23,6 +25,8 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java
index 75ce3be3b491..9dfbf302f27d 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java
@@ -16,6 +16,8 @@
package org.deeplearning4j.nn.graph.vertex.impl.rnn;
+import java.util.Map;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
@@ -23,6 +25,9 @@
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
+import org.nd4j.autodiff.samediff.SDIndex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -65,6 +70,12 @@ public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex,
+ "of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")");
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
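+ // Takes the last time step (index -1 along the time dimension); the mask argument is not applied here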
+ return inputs[0].get(SDIndex.all(), SDIndex.all(), SDIndex.point(-1));
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java
index 0d75119de0d4..8ece4f90b87a 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java
@@ -16,12 +16,16 @@
package org.deeplearning4j.nn.graph.vertex.impl.rnn;
+import java.util.Map;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
@@ -63,6 +67,12 @@ public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIn
}
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable> paramTable) {
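+ // Reverse along dimension 2 (the time dimension for NCW-format activations)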
+ return sameDiff.reverse(inputs[0], 2);
+ }
+
@Override
public boolean hasLayer() {
return false;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java
index c1c92d3a7322..016ec73a4687 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java
@@ -75,7 +75,12 @@ public INDArray activate(boolean training, LayerWorkspaceMgr mgr) {
//dup required: need to keep original input for backprop
in = mgr.dup(ArrayType.ACTIVATIONS, input, input.ordering());
} else {
- in = mgr.leverageTo(ArrayType.ACTIVATIONS, input);
+ if(mgr.isScopedOut(ArrayType.ACTIVATIONS) && !input.isAttached()) {
+ //Edge case: input and output are both not in workspaces - dup to avoid inplace modification
+ in = mgr.dup(ArrayType.ACTIVATIONS, input);
+ } else {
+ in = mgr.leverageTo(ArrayType.ACTIVATIONS, input);
+ }
}
return layerConf().getActivationFn().getActivation(in, training);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
index a10bb33f3333..fcbe633216a0 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java
@@ -354,4 +354,8 @@ public boolean isPretrainLayer() {
public boolean hasBias() {
return layerConf().hasBias();
}
+
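+ // Expose the configured loss function from the layer configuration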
+ public ILossFunction getLossFn() {
+ return layerConf().getLossFn();
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java
index 7180ff446d7b..b9ad155a1a0c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java
@@ -337,4 +337,8 @@ protected INDArray getLabels2d() {
return labels;
}
+ public ILossFunction getLossFn() {
+ return layerConf().getLossFn();
+ }
+
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java
index 212161a9e02b..1130ccdc7848 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java
@@ -218,6 +218,15 @@ public boolean needsLabels() {
@Override
public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
+
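+ // Validate input rank and labels before computing the score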
+ assertInputSet(false);
+ if (input.rank() != 5)
+ throw new UnsupportedOperationException(
+ "Input is not rank 5. Got input with rank " + input.rank() + " " + layerId() + " with shape "
+ + Arrays.toString(input.shape()) + " - expected shape [minibatch,channels,depth,height,width]");
+ if (labels == null)
+ throw new IllegalStateException("Labels are not set (null)");
+
INDArray input2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray labels2d = ConvolutionUtils.reshape5dTo2d(layerConf().getDataFormat(), labels, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray maskReshaped = ConvolutionUtils.reshapeCnn3dMask(layerConf().getDataFormat(), maskArray, input, workspaceMgr, ArrayType.FF_WORKING_MEM);
@@ -279,4 +288,8 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, summedScores);
}
+
+ public ILossFunction getLossFn() {
+ return layerConf().getLossFn();
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java
index 06c3b237544c..c2338cd1affd 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/CnnLossLayer.java
@@ -198,6 +198,16 @@ public boolean needsLabels() {
@Override
public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
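+ // Validate input rank, labels, and matching input/label shapes before computing the score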
+ assertInputSet(false);
+ if (input.rank() != 4)
+ throw new UnsupportedOperationException(
+ "Input is not rank 4. Got input with rank " + input.rank() + " " + layerId() + " with shape "
+ + Arrays.toString(input.shape()) + " - expected shape " + layerConf().getFormat().dimensionNames());
+ if (labels == null)
+ throw new IllegalStateException("Labels are not set (null)");
+
+ Preconditions.checkState(input.equalShapes(labels), "Input and label arrays do not have same shape: %ndShape vs. %ndShape",input, labels);
+
INDArray input2d = ConvolutionUtils.reshape4dTo2d(input, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray labels2d = ConvolutionUtils.reshape4dTo2d(labels, workspaceMgr, ArrayType.FF_WORKING_MEM);
INDArray maskReshaped = ConvolutionUtils.reshapeMaskIfRequired(maskArray, input, layerConf().getFormat(), workspaceMgr, ArrayType.FF_WORKING_MEM);
@@ -248,4 +258,8 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, summedScores);
}
+
+ public ILossFunction getLossFn() {
+ return layerConf().getLossFn();
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java
index 4d118c62bf0d..d160854bdcf7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java
@@ -680,4 +680,8 @@ public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int cl
INDArray conf = networkOutput.get(point(example), point(5*bbs + classNumber), all(), all());
return conf;
}
+
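+ // YOLO2 computes its loss internally from multiple weighted components, so there is no single ILossFunction to expose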
+ public ILossFunction getLossFn() {
+ return null;
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java
index e6c14f2e4703..ec26baf7b18a 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayer.java
@@ -18,6 +18,7 @@
import lombok.Getter;
+import lombok.NonNull;
import lombok.Setter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -26,6 +27,9 @@
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.api.buffer.DataType;
@@ -34,6 +38,7 @@
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.lossfunctions.BaseLossFunction;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair;
@@ -289,7 +294,7 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr
return summedScores;
}
- public class OCNNLossFunction implements ILossFunction {
+ public class OCNNLossFunction extends BaseLossFunction {
@Override
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java
index 28913681f67d..ced54fec3bda 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java
@@ -213,6 +213,22 @@ public boolean needsLabels() {
@Override
public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
+ assertInputSet(false);
+ if (input.rank() != 3)
+ throw new UnsupportedOperationException(
+ "Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +
+ input.rank() + " with shape " + Arrays.toString(input.shape()) + " for layer " + layerId());
+ if (labels == null)
+ throw new IllegalStateException("Labels are not set (null)");
+
+ //Note: the pre-existing code below permutes local copies of input/labels for NWC data, so the fields must
+ //not be permuted here as well - instead, check the sequence-length dimension appropriate to the data format
+ int td = (layerConf().getRnnDataFormat() == RNNFormat.NCW) ? 2 : 1;
+ Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
+ Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnLossLayer input and labels:" +
+ "Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on the sequence length dimension - input=%ndShape vs. label=%ndShape", input, labels);
+
INDArray input = this.input;
INDArray labels = this.labels;
if (layerConf().getRnnDataFormat() == RNNFormat.NWC){
@@ -288,4 +304,8 @@ public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr
return summedScores;
}
+
+ public ILossFunction getLossFn() {
+ return layerConf().getLossFn();
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java
index 979d2b7be23f..0bdf77cb63b2 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java
@@ -109,6 +109,12 @@ public Layer.Type type() {
protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
if (input.rank() == 3) {
+
+ if (labels != null) { //Validation only applies when labels are set: preOutput2d is also called at inference time
+ RNNFormat format = layerConf().getRnnDataFormat();
+ int td = (format == RNNFormat.NCW) ? 2 : 1;
+ Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
+ Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
+ "Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on the sequence length dimension - input=%ndShape vs. label=%ndShape", input, labels);
+ }
//Case when called from RnnOutputLayer
INDArray inputTemp = input;
input = (layerConf().getRnnDataFormat()==RNNFormat.NWC)? input.permute(0, 2, 1):input;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java
index 7fc7af03f59f..496b4ccbe1f2 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java
@@ -16,6 +16,7 @@
package org.deeplearning4j.nn.layers.samediff;
+import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
@@ -82,6 +83,23 @@ public SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String
this.params = paramsView;
}
+ @Override
+ public SDVariable defineVertex(@NonNull SameDiff sameDiff, @NonNull SDVariable[] inputs,
+ SDVariable mask, @NonNull Map<String, SDVariable>