diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
index ab6c8215a064..d9e7923b8a23 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
@@ -91,7 +91,8 @@ public void lossFunctionGradientCheck() {
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         new LossMultiLabel(), new LossWasserstein(),
                         new LossSparseMCXENT(),
-                        new SDLossMAE(), new SDLossMSE()
+                        new SDLossMAE(), new SDLossMSE(),
+                        new LossFloat(new LossMSE(), 0.1), new LossFloat(new LossMCXENT(), 0.1)
         };
 
         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -132,7 +133,9 @@ public void lossFunctionGradientCheck() {
                         Activation.TANH, // SDLossMAE
                         Activation.SOFTMAX, // SDLossMSE
                         Activation.SIGMOID, // SDLossMSE
-                        Activation.TANH //SDLossMSE
+                        Activation.TANH, //SDLossMSE
+                        Activation.IDENTITY,// Float Loss + MSE
+                        Activation.SOFTMAX // Float Loss + MCXENT
         };
 
         int[] nOut = new int[] {1, //xent
@@ -174,6 +177,8 @@ public void lossFunctionGradientCheck() {
                         3, // SDLossMSE
                         3, // SDLossMSE
                         3, // SDLossMSE
+                        3, // Float Loss + MSE
+                        3 // Float Loss + MCXENT
         };
 
         int[] minibatchSizes = new int[] {1, 3};
@@ -255,7 +260,9 @@ public void lossFunctionGradientCheckLossLayer() {
                         new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         new LossMultiLabel(), new LossWasserstein(),
-                        new LossSparseMCXENT()
+                        new LossSparseMCXENT(),
+                        new LossFloat(new LossMSE(), 0.1), new LossFloat(new LossMCXENT(), 0.1)
+
         };
 
         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -289,7 +296,9 @@ public void lossFunctionGradientCheckLossLayer() {
                         Activation.TANH, // MixtureDensity + tanh
                         Activation.TANH, // MultiLabel
                         Activation.IDENTITY, // Wasserstein
-                        Activation.SOFTMAX
+                        Activation.SOFTMAX,
+                        Activation.IDENTITY,// Float Loss + MSE
+                        Activation.SOFTMAX // Float Loss + MCXENT
         };
 
         int[] nOut = new int[] {1, //xent
@@ -323,7 +332,9 @@ public void lossFunctionGradientCheckLossLayer() {
                         10, // Mixture Density + tanh
                         10, // MultiLabel
                         2, // Wasserstein
-                        4
+                        4,
+                        3, // Float Loss + MSE
+                        3 // Float Loss + MCXENT
         };
 
         int[] minibatchSizes = new int[] {1, 3};
@@ -482,6 +493,7 @@ public static INDArray[] getFeaturesAndLabels(ILossFunction l, long[] featuresSh
                         throw new RuntimeException();
                 }
                 break;
+            case "LossFloat":
             case "LossMCXENT":
             case "LossNegativeLogLikelihood":
                 ret[1] = Nd4j.zeros(labelsShape);
@@ -625,7 +637,9 @@ public void lossFunctionWeightedGradientCheck() {
         ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(w), new LossL1(w), new LossL1(w),
                         new LossL2(w), new LossL2(w), new LossMAE(w), new LossMAE(w), new LossMAPE(w), new LossMAPE(w),
                         new LossMCXENT(w), new LossMSE(w), new LossMSE(w), new LossMSLE(w),
-                        new LossMSLE(w), new LossNegativeLogLikelihood(w), new LossNegativeLogLikelihood(w),};
+                        new LossMSLE(w), new LossNegativeLogLikelihood(w), new LossNegativeLogLikelihood(w),
+                        new LossFloat(new LossNegativeLogLikelihood(w), 0.1)
+        };
 
         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
                         Activation.TANH, //l1
@@ -643,6 +657,7 @@ public void lossFunctionWeightedGradientCheck() {
                         Activation.SOFTMAX, //msle + softmax
                         Activation.SIGMOID, //nll
                         Activation.SOFTMAX, //nll + softmax
+                        Activation.SOFTMAX, //float + nll + softmax
         };
 
         int[] minibatchSizes = new int[] {1, 3};
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFloat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFloat.java
new file mode 100644
index 000000000000..eb064569d308
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossFloat.java
@@ -0,0 +1,103 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.lossfunctions.impl;
+
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.primitives.Pair;
+
+/**
+ * Implementation of "Do We Need Zero Training Loss After Achieving Zero Training Error?"
+ * https://arxiv.org/abs/2002.08709
+ *
+ * Wraps any other loss function that is optimized towards zero and adds a "float value" to it.
+ *
+ * Intuitively, this means that your model will not be able to get below a specific loss value. If the loss on an
+ * example gets below this float value, the gradient for this example is inverted, i.e. instead of descending along
+ * this gradient we ascend on it instead.
+ *
+ * The effect of using this loss function is regularization: it makes it harder for the model to over-fit the
+ * training data. The float level is a hyperparameter, and selecting a good one can take a few tries.
+ *
+ * A float value of 0 is the same as not using this loss function at all; any value above it is valid, but you
+ * should likely stay below 0.3. A good starting point is the loss score of your best-performing model
+ * when using early stopping.
+ *
+ * @author Paul Dubs
+ */
+@Getter
+@Setter
+@EqualsAndHashCode
+public class LossFloat implements ILossFunction {
+
+    private ILossFunction wrapped;
+    private double floatLevel;
+
+    // For (De-)Serialization
+    private LossFloat(){}
+
+    public LossFloat(ILossFunction wrapped, double floatLevel) {
+        this.wrapped = wrapped;
+        this.floatLevel = floatLevel;
+    }
+
+    @Override
+    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
+                    boolean average) {
+        INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);
+
+        double score = scoreArr.sumNumber().doubleValue();
+
+        if (average)
+            score /= scoreArr.size(0);
+
+        return score;
+    }
+
+    @Override
+    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        return Nd4j.math.abs(wrapped.computeScoreArray(labels, preOutput, activationFn, mask).subi(floatLevel)).addi(floatLevel);
+    }
+
+    @Override
+    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray gradient = wrapped.computeGradient(labels, preOutput, activationFn, mask);
+        return gradient.muli(Nd4j.math.sign(wrapped.computeScoreArray(labels, preOutput, activationFn, mask).sub(floatLevel)));
+    }
+
+    @Override
+    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels,
+                    INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
+        return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average),
+                        computeGradient(labels, preOutput, activationFn, mask));
+    }
+
+    @Override
+    public String name() {
+        return toString();
+    }
+
+    @Override
+    public String toString() {
+        return "LossFloatWrapper("+wrapped.name()+", "+floatLevel+")";
+    }
+}
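
Note for reviewers: below is a minimal usage sketch, not part of the patch itself, showing how the new wrapper is expected to be plugged into a DL4J output layer. It mirrors the new LossFloat(new LossMCXENT(), 0.1) configurations exercised by the gradient checks above; the class name LossFloatUsageSketch, the layer sizes, and the seed are illustrative assumptions rather than code from this PR.

// Usage sketch (assumption): flooring a softmax classification loss at a float level of 0.1.
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossFloat;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;

public class LossFloatUsageSketch {

    public static MultiLayerConfiguration buildConfig() {
        return new NeuralNetConfiguration.Builder()
                .seed(12345)
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.TANH).build())
                // Wrap the classification loss: the per-example score becomes |loss - 0.1| + 0.1, so it
                // never drops below 0.1. Once an example's loss is under the float level, the gradient
                // sign flips and the optimizer ascends on that example instead of descending.
                .layer(new OutputLayer.Builder()
                        .lossFunction(new LossFloat(new LossMCXENT(), 0.1))
                        .activation(Activation.SOFTMAX)
                        .nIn(10).nOut(3)
                        .build())
                .build();
    }
}

As in the tests, any ILossFunction that is optimized towards zero can be wrapped the same way, e.g. new LossFloat(new LossMSE(), 0.1) for regression.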