@@ -91,7 +91,8 @@ public void lossFunctionGradientCheck() {
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein(),
new LossSparseMCXENT(),
-new SDLossMAE(), new SDLossMSE()
+new SDLossMAE(), new SDLossMSE(),
+new LossFloat(new LossMSE(), 0.1), new LossFloat(new LossMCXENT(), 0.1)
};

Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -132,7 +133,9 @@ public void lossFunctionGradientCheck() {
Activation.TANH, // SDLossMAE
Activation.SOFTMAX, // SDLossMSE
Activation.SIGMOID, // SDLossMSE
-Activation.TANH //SDLossMSE
+Activation.TANH, //SDLossMSE
+Activation.IDENTITY, // Float Loss + MSE
+Activation.SOFTMAX // Float Loss + MCXENT
};

int[] nOut = new int[] {1, //xent
@@ -174,6 +177,8 @@ public void lossFunctionGradientCheck() {
3, // SDLossMSE
3, // SDLossMSE
3, // SDLossMSE
+3, // Float Loss + MSE
+3 // Float Loss + MCXENT
};

int[] minibatchSizes = new int[] {1, 3};
@@ -255,7 +260,9 @@ public void lossFunctionGradientCheckLossLayer() {
new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein(),
-new LossSparseMCXENT()
+new LossSparseMCXENT(),
+new LossFloat(new LossMSE(), 0.1), new LossFloat(new LossMCXENT(), 0.1)
+
};

Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -289,7 +296,9 @@ public void lossFunctionGradientCheckLossLayer() {
Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel
Activation.IDENTITY, // Wasserstein
-Activation.SOFTMAX
+Activation.SOFTMAX,
+Activation.IDENTITY, // Float Loss + MSE
+Activation.SOFTMAX // Float Loss + MCXENT
};

int[] nOut = new int[] {1, //xent
Expand Down Expand Up @@ -323,7 +332,9 @@ public void lossFunctionGradientCheckLossLayer() {
10, // Mixture Density + tanh
10, // MultiLabel
2, // Wasserstein
-4
+4,
+3, // Float Loss + MSE
+3 // Float Loss + MCXENT
};

int[] minibatchSizes = new int[] {1, 3};
@@ -482,6 +493,7 @@ public static INDArray[] getFeaturesAndLabels(ILossFunction l, long[] featuresSh
throw new RuntimeException();
}
break;
case "LossFloat":
case "LossMCXENT":
case "LossNegativeLogLikelihood":
ret[1] = Nd4j.zeros(labelsShape);
@@ -625,7 +637,9 @@ public void lossFunctionWeightedGradientCheck() {
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(w), new LossL1(w), new LossL1(w),
new LossL2(w), new LossL2(w), new LossMAE(w), new LossMAE(w), new LossMAPE(w),
new LossMAPE(w), new LossMCXENT(w), new LossMSE(w), new LossMSE(w), new LossMSLE(w),
-new LossMSLE(w), new LossNegativeLogLikelihood(w), new LossNegativeLogLikelihood(w),};
+new LossMSLE(w), new LossNegativeLogLikelihood(w), new LossNegativeLogLikelihood(w),
+new LossFloat(new LossNegativeLogLikelihood(w), 0.1)
+};

Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
Activation.TANH, //l1
@@ -643,6 +657,7 @@ public void lossFunctionWeightedGradientCheck() {
Activation.SOFTMAX, //msle + softmax
Activation.SIGMOID, //nll
Activation.SOFTMAX, //nll + softmax
+Activation.SOFTMAX, //float + nll + softmax
};

int[] minibatchSizes = new int[] {1, 3};
LossFloat.java (new file)
@@ -0,0 +1,103 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.linalg.lossfunctions.impl;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;

/**
* Implementation of "Do We Need Zero Training Loss After Achieving Zero Training Error?"
* https://arxiv.org/abs/2002.08709
*
* Wraps any other loss function that is optimized towards zero and adds a "float level" to it: the wrapped
* per-example loss L becomes |L - floatLevel| + floatLevel.
*
* Intuitively, this means that your model will not be able to get below a specific loss value. If the loss on an
* example drops below the float level, the gradient for that example is inverted, i.e. instead of descending on
* the gradient, we ascend it. For example, with a float level of 0.1, a per-example loss of 0.04 becomes
* |0.04 - 0.1| + 0.1 = 0.16, and the gradient for that example is negated.
*
* The effect of using this loss function is regularization: it makes it harder for the model to over-fit the
* training data. The float level is a hyperparameter, and selecting a good one can take a few tries.
*
* A float level of 0 is the same as not using this wrapper at all. Anything above 0 is a valid value, but you
* should likely stay below 0.3. A good starting point is the loss score of your best-performing model when
* using early stopping.
*
* @author Paul Dubs
*/
@Getter
@Setter
@EqualsAndHashCode
public class LossFloat implements ILossFunction {

private ILossFunction wrapped;
private double floatLevel;

// For (De-)Serialization
private LossFloat(){}

public LossFloat(ILossFunction wrapped, double floatLevel) {
this.wrapped = wrapped;
this.floatLevel = floatLevel;
}

@Override
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
boolean average) {
INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);

double score = scoreArr.sumNumber().doubleValue();

if (average)
score /= scoreArr.size(0);

return score;
}

@Override
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
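// Flooded per-example score: |L - floatLevel| + floatLevel, where L is the wrapped loss for each example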
return Nd4j.math.abs(wrapped.computeScoreArray(labels, preOutput, activationFn, mask).subi(floatLevel)).addi(floatLevel);
}

@Override
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
INDArray gradient = wrapped.computeGradient(labels, preOutput, activationFn, mask);
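// Invert the gradient wherever the wrapped per-example loss is already below the float level,
// so the optimizer ascends on those examples instead of descending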
return gradient.muli(Nd4j.math.sign(wrapped.computeScoreArray(labels, preOutput, activationFn, mask).sub(floatLevel)));
}

@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels,
INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average),
computeGradient(labels, preOutput, activationFn, mask));
}

@Override
public String name() {
return toString();
}

@Override
public String toString() {
return "LossFloatWrapper("+wrapped.name()+", "+floatLevel+")";
}
}
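
Usage note (not part of the diff above): a minimal sketch of how LossFloat might be plugged into a DL4J output layer. The network shape, seed, and updater values here are illustrative assumptions, not taken from this PR; only LossFloat, LossMCXENT, and the lossFunction(ILossFunction) builder method come from the libraries themselves.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.impl.LossFloat;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;

public class LossFloatExample {
    public static void main(String[] args) {
        // Wrap MCXENT with a float level of 0.1, so per-example training loss cannot settle below 0.1
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)              // illustrative values, not from the PR
                .updater(new Adam(1e-3))
                .list()
                .layer(new DenseLayer.Builder().nIn(10).nOut(32).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder()
                        .lossFunction(new LossFloat(new LossMCXENT(), 0.1))
                        .activation(Activation.SOFTMAX)
                        .nIn(32).nOut(3)
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        System.out.println(net.summary());
    }
}

Because |L - floatLevel| + floatLevel >= floatLevel for any L, the reported score can never drop below the float level during training.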