Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ class NormalizationContextIntegTest extends SparkTestUtils with GameTestUtils {

// Train the original data with a loss function binding normalization
val zero = Vector.zeros[Double](heartDataRDD.first.features.length)
val (model1, objective1) = optimizerNorm.optimize(objectiveFunction, zero)(heartDataRDD)
val (model1, _, objective1) = optimizerNorm.optimize(objectiveFunction, zero)(heartDataRDD)
// Train the transformed data with a normal loss function
val (model2, objective2) = optimizerNoNorm.optimize(objectiveFunction, zero)(transformedRDD)
val (model2, _, objective2) = optimizerNoNorm.optimize(objectiveFunction, zero)(transformedRDD)

heartDataRDD.unpersist()
transformedRDD.unpersist()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,20 @@ class BaseGLMIntegTest extends SparkTestUtils {
val optimizationProblem = optimizationProblemBuilder(normalizationContext)

// Step 2: Run optimization
val models = BaseGLMIntegTest.LAMBDAS.map { lambda =>
val modelsAndStateTrackers = BaseGLMIntegTest.LAMBDAS.map { lambda =>
optimizationProblem.updateRegularizationWeight(lambda)
val result = optimizationProblem.run(trainingSet)
val statesTracker = optimizationProblem.getStatesTracker
val (model, statesTracker) = optimizationProblem.run(trainingSet)

// Step 3: Check convergence
BaseGLMIntegTest.checkConvergence(statesTracker)

result
(model, statesTracker)
}

// Step 4: Validate the models
models.foreach( m => {
m.validateCoefficients()
validator.validateModelPredictions(m, trainingSet)
modelsAndStateTrackers.foreach( t => {
t._1.validateCoefficients()
validator.validateModelPredictions(t._1, trainingSet)
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ object ModelTraining extends Logging {
// Initialize the list with the result from the first regularization weight
optimizationProblem.updateRegularizationWeight(currentWeight)

val glm = if (numWarmStartModels == 0) {
val (glm, stateTracker) = if (numWarmStartModels == 0) {

logger.info(s"No warm start model found; beginning training with a 0-coefficients model")

Expand All @@ -199,14 +199,14 @@ object ModelTraining extends Logging {
optimizationProblem.run(trainingData, warmStartModels(maxLambda))
}

List((currentWeight, glm, optimizationProblem.getStatesTracker))
List((currentWeight, glm, stateTracker))

case (latestWeightsModelsAndTrackers, currentWeight) =>

optimizationProblem.updateRegularizationWeight(currentWeight)

// Train the rest of the models
val glm = if (useWarmStart) {
val (glm, stateTracker) = if (useWarmStart) {
val previousModel = latestWeightsModelsAndTrackers.head._2

logger.info(s"Training model with regularization weight $currentWeight started (warm start)")
Expand All @@ -219,7 +219,7 @@ object ModelTraining extends Logging {
optimizationProblem.run(trainingData)
}

(currentWeight, glm, optimizationProblem.getStatesTracker) +: latestWeightsModelsAndTrackers
(currentWeight, glm, stateTracker) +: latestWeightsModelsAndTrackers
}

broadcastNormalizationContext.unpersist()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ object FixedEffectCoordinate {

val newModel = initialFixedEffectModelOpt
.map { initialFixedEffectModel =>
optimizationProblem.runWithSampling(input, initialFixedEffectModel.model)
val (model, _) = optimizationProblem.runWithSampling(input, initialFixedEffectModel.model)
model
}
.getOrElse(optimizationProblem.runWithSampling(input))
.getOrElse(optimizationProblem.runWithSampling(input)._1)
val updatedModelBroadcast = input.sparkContext.broadcast(newModel)

new FixedEffectModel(updatedModelBroadcast, featureShardId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package com.linkedin.photon.ml.algorithm

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

import com.linkedin.photon.ml.data._
Expand Down Expand Up @@ -246,8 +245,7 @@ object RandomEffectCoordinate {
.mapValues {
case (Some(localModel), Some((localDataset, optimizationProblem))) =>
val trainingLabeledPoints = localDataset.dataPoints.map(_._2)
val updatedModel = optimizationProblem.run(trainingLabeledPoints, localModel)
val stateTrackers = optimizationProblem.getStatesTracker
val (updatedModel, stateTrackers) = optimizationProblem.run(trainingLabeledPoints, localModel)

(updatedModel, Some(stateTrackers))

Expand All @@ -256,8 +254,7 @@ object RandomEffectCoordinate {

case (None, Some((localDataset, optimizationProblem))) =>
val trainingLabeledPoints = localDataset.dataPoints.map(_._2)
val updatedModel = optimizationProblem.run(trainingLabeledPoints)
val stateTrackers = optimizationProblem.getStatesTracker
val (updatedModel,stateTrackers) = optimizationProblem.run(trainingLabeledPoints)

(updatedModel, Some(stateTrackers))

Expand All @@ -274,10 +271,9 @@ object RandomEffectCoordinate {
.getOrElse {
val modelsAndTrackers = dataAndOptimizationProblems.mapValues { case (localDataset, optimizationProblem) =>
val trainingLabeledPoints = localDataset.dataPoints.map(_._2)
val newModel = optimizationProblem.run(trainingLabeledPoints)
val stateTrackers = optimizationProblem.getStatesTracker
val (model, stateTrackers) = optimizationProblem.run(trainingLabeledPoints)

(newModel, stateTrackers)
(model, stateTrackers)
}
modelsAndTrackers.persist(StorageLevel.MEMORY_ONLY_SER)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
*/
package com.linkedin.photon.ml.optimization

import breeze.linalg.{Vector, cholesky, diag}
import breeze.linalg.{cholesky, diag, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -109,8 +109,9 @@ protected[ml] class DistributedOptimizationProblem[Objective <: DistributedObjec
* @param input The training data
* @return The learned [[GeneralizedLinearModel]]
*/
override def run(input: RDD[LabeledPoint]): GeneralizedLinearModel =
override def run(input: RDD[LabeledPoint]): (GeneralizedLinearModel, OptimizationStatesTracker) = {
run(input, initializeZeroModel(input.first.features.size))
}

/**
* Run the algorithm with the configured parameters, starting from the initial model provided
Expand All @@ -120,13 +121,13 @@ protected[ml] class DistributedOptimizationProblem[Objective <: DistributedObjec
* @param initialModel The initial model from which to begin optimization
* @return The learned [[GeneralizedLinearModel]]
*/
override def run(input: RDD[LabeledPoint], initialModel: GeneralizedLinearModel): GeneralizedLinearModel = {
override def run(input: RDD[LabeledPoint], initialModel: GeneralizedLinearModel): (GeneralizedLinearModel, OptimizationStatesTracker) = {

val normalizationContext = optimizer.getNormalizationContext
val (optimizedCoefficients, _) = optimizer.optimize(objectiveFunction, initialModel.coefficients.means)(input)
val (optimizedCoefficients, stateTracker, _) = optimizer.optimize(objectiveFunction, initialModel.coefficients.means)(input)
val optimizedVariances = computeVariances(input, optimizedCoefficients)

createModel(normalizationContext, optimizedCoefficients, optimizedVariances)
(createModel(normalizationContext, optimizedCoefficients, optimizedVariances), stateTracker)
}

/**
Expand All @@ -136,7 +137,7 @@ protected[ml] class DistributedOptimizationProblem[Objective <: DistributedObjec
* @param input The training data
* @return The learned [[GeneralizedLinearModel]]
*/
def runWithSampling(input: RDD[(UniqueSampleId, LabeledPoint)]): GeneralizedLinearModel =
def runWithSampling(input: RDD[(UniqueSampleId, LabeledPoint)]): (GeneralizedLinearModel, OptimizationStatesTracker) =
runWithSampling(input, initializeZeroModel(input.first._2.features.size))

/**
Expand All @@ -149,19 +150,19 @@ protected[ml] class DistributedOptimizationProblem[Objective <: DistributedObjec
*/
def runWithSampling(
input: RDD[(UniqueSampleId, LabeledPoint)],
initialModel: GeneralizedLinearModel): GeneralizedLinearModel = {
initialModel: GeneralizedLinearModel): (GeneralizedLinearModel, OptimizationStatesTracker) = {

val data = (samplerOption match {
case Some(sampler) => sampler.downSample(input).values
case None => input.values
})
.setName("In memory fixed effect training dataset")
.persist(StorageLevel.MEMORY_AND_DISK)
val result = run(data, initialModel)
val (model, stateTracker) = run(data, initialModel)

data.unpersist()

result
(model, stateTracker)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ protected[ml] abstract class GeneralizedLinearOptimizationProblem[Objective <: O
* @param input The training data
* @return The learned GLM for the given optimization problem, data, regularization type, and regularization weight
*/
def run(input: objectiveFunction.Data): GeneralizedLinearModel
def run(input: objectiveFunction.Data): (GeneralizedLinearModel, OptimizationStatesTracker)

/**
* Run the optimization algorithm on the input data, starting from the initial model provided.
Expand All @@ -107,7 +107,7 @@ protected[ml] abstract class GeneralizedLinearOptimizationProblem[Objective <: O
* @param initialModel The initial model from which to begin optimization
* @return The learned GLM for the given optimization problem, data, regularization type, and regularization weight
*/
def run(input: objectiveFunction.Data, initialModel: GeneralizedLinearModel): GeneralizedLinearModel
def run(input: objectiveFunction.Data, initialModel: GeneralizedLinearModel): (GeneralizedLinearModel, OptimizationStatesTracker)

/**
* Compute the regularization term value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected[ml] class SingleNodeOptimizationProblem[Objective <: SingleNodeObjecti
* @param input The training data
* @return The learned GLM for the given optimization problem, data, regularization type, and regularization weight
*/
override def run(input: Iterable[LabeledPoint]): GeneralizedLinearModel =
override def run(input: Iterable[LabeledPoint]): (GeneralizedLinearModel, OptimizationStatesTracker) =
run(input, initializeZeroModel(input.head.features.size))

/**
Expand All @@ -87,13 +87,13 @@ protected[ml] class SingleNodeOptimizationProblem[Objective <: SingleNodeObjecti
* @param initialModel The initial model from which to begin optimization
* @return The learned GLM for the given optimization problem, data, regularization type, and regularization weight
*/
override def run(input: Iterable[LabeledPoint], initialModel: GeneralizedLinearModel): GeneralizedLinearModel = {
override def run(input: Iterable[LabeledPoint], initialModel: GeneralizedLinearModel): (GeneralizedLinearModel, OptimizationStatesTracker) = {

val normalizationContext = optimizer.getNormalizationContext
val (optimizedCoefficients, _) = optimizer.optimize(objectiveFunction, initialModel.coefficients.means)(input)
val (optimizedCoefficients, stateTracker, _) = optimizer.optimize(objectiveFunction, initialModel.coefficients.means)(input)
val optimizedVariances = computeVariances(input, optimizedCoefficients)

createModel(normalizationContext, optimizedCoefficients, optimizedVariances)
(createModel(normalizationContext, optimizedCoefficients, optimizedVariances), stateTracker)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ class FixedEffectCoordinateTest {
case Some(initialModel) =>
val rawModel = initialModel.model

doReturn(updatedModel).when(optimizationProblem).runWithSampling(labeledPoints, rawModel)
doReturn((updatedModel, statesTracker)).when(optimizationProblem).runWithSampling(labeledPoints, rawModel)

coordinate.trainModel(initialModel)

case None =>
doReturn(updatedModel).when(optimizationProblem).runWithSampling(labeledPoints)
doReturn((updatedModel, statesTracker)).when(optimizationProblem).runWithSampling(labeledPoints)

coordinate.trainModel()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ class DistributedOptimizationProblemTest {
doReturn(broadcastCoefficients).when(sparkContext).broadcast(means)
doReturn(broadcastNormalization).when(optimizer).getNormalizationContext
doReturn(normalization).when(broadcastNormalization).value
doReturn((means, None)).when(optimizer).optimize(objectiveFunction, means)(trainingData)
doReturn(statesTracker).when(optimizer).getStateTracker
doReturn(Array(state)).when(statesTracker).getTrackedStates
doReturn((means, statesTracker, None)).when(optimizer).optimize(objectiveFunction, means)(trainingData)
doReturn(Array(state)).when(statesTracker).getTrackedStates
doReturn(means).when(state).coefficients
doReturn(coefficients).when(initialModel).coefficients
doReturn(means).when(coefficients).means
Expand All @@ -154,7 +153,7 @@ class DistributedOptimizationProblemTest {
NoRegularizationContext,
VarianceComputationType.NONE)

val model = problem.run(trainingData, initialModel)
val (model, _) = problem.run(trainingData, initialModel)

assertTrue(means.eq(model.coefficients.means))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ object GeneralizedLinearOptimizationProblemTest {
VarianceComputationType.NONE) {

private val mockGLM = mock(classOf[GeneralizedLinearModel])
private val mockStateTracker = mock(classOf[OptimizationStatesTracker])

//
// Public versions of protected methods for testing
Expand Down Expand Up @@ -242,13 +243,13 @@ object GeneralizedLinearOptimizationProblemTest {
/**
* Unused - needs definition for testing.
*/
override def run(input: Iterable[LabeledPoint]): GeneralizedLinearModel = mockGLM
override def run(input: Iterable[LabeledPoint]): (GeneralizedLinearModel, OptimizationStatesTracker) = (mockGLM, mockStateTracker)

/**
* Unused - needs definition for testing.
*/
override def run(input: Iterable[LabeledPoint], initialModel: GeneralizedLinearModel): GeneralizedLinearModel =
mockGLM
override def run(input: Iterable[LabeledPoint], initialModel: GeneralizedLinearModel): (GeneralizedLinearModel, OptimizationStatesTracker) =
(mockGLM, mockStateTracker)
}

// No way to pass Mixin class type to Mockito, need to define a concrete class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ class SingleNodeOptimizationProblemTest {

doReturn(broadcastNormalization).when(optimizer).getNormalizationContext
doReturn(normalization).when(broadcastNormalization).value
doReturn((means, None)).when(optimizer).optimize(objectiveFunction, means)(trainingData)
doReturn(statesTracker).when(optimizer).getStateTracker
doReturn((means, statesTracker, None)).when(optimizer).optimize(objectiveFunction, means)(trainingData)
doReturn(Array(state)).when(statesTracker).getTrackedStates
doReturn(means).when(state).coefficients
doReturn(coefficients).when(initialModel).coefficients
Expand All @@ -136,7 +135,7 @@ class SingleNodeOptimizationProblemTest {
LogisticRegressionModel.apply,
VarianceComputationType.NONE)

val model = problem.run(trainingData, initialModel)
val (model, _) = problem.run(trainingData, initialModel)

assertTrue(means.eq(model.coefficients.means))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ abstract class Optimizer[-Function <: ObjectiveFunction](
protected[ml] def optimize(
objectiveFunction: Function,
initialCoefficients: Vector[Double])(
data: objectiveFunction.Data): (Vector[Double], Double) = {
data: objectiveFunction.Data): (Vector[Double], OptimizationStatesTracker, Double) = {

val normalizedInitialCoefficients = normalizationContext.value.modelToTransformedSpace(initialCoefficients)

Expand All @@ -183,7 +183,7 @@ abstract class Optimizer[-Function <: ObjectiveFunction](

statesTracker.convergenceReason = getConvergenceReason
val currState = getCurrentState.get
(currState.coefficients, currState.loss)
(currState.coefficients, statesTracker, currState.loss)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class LBFGSBTest {
(lowerBounds(pos) + upperBounds(pos)) / 2
}
.toArray
val (actualCoef, actualValue) = lbfgsb.optimize(objective, DenseVector[Double](initialCoefficients))(trainingData)
val (actualCoef, _, actualValue) = lbfgsb.optimize(objective, DenseVector[Double](initialCoefficients))(trainingData)

Assertions.assertIterableEqualsWithTolerance(actualCoef.toArray, expectedCoef, LBFGSBTest.EPSILON)
assertEquals(actualValue, expectedValue, LBFGSBTest.EPSILON)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class LBFGSTest {
val objective = new TestObjective
val trainingData = Array(LabeledPoint(0.0, CommonTestUtils.generateDenseVector(1), 0.0, 0.0))
val initialCoefficients = CommonTestUtils.generateDenseVector(1)
val (actualCoef, actualValue) = lbfgs.optimize(objective, initialCoefficients)(trainingData)
val (actualCoef, _, actualValue) = lbfgs.optimize(objective, initialCoefficients)(trainingData)

Assertions.assertIterableEqualsWithTolerance(actualCoef.toArray, Array(TestObjective.CENTROID), LBFGSTest.EPSILON)
assertEquals(actualValue, LBFGSTest.EXPECTED_LOSS, LBFGSTest.EPSILON)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class OWLQNTest {
val objective = new TestObjective
val trainingData = Array(LabeledPoint(0.0, CommonTestUtils.generateDenseVector(expectedCoef.length), 0.0, 0.0))
val initialCoefficients = CommonTestUtils.generateDenseVector(expectedCoef.length)
val (actualCoef, actualValue) = owlqn.optimize(objective, initialCoefficients)(trainingData)
val (actualCoef, _, actualValue) = owlqn.optimize(objective, initialCoefficients)(trainingData)

Assertions.assertIterableEqualsWithTolerance(actualCoef.toArray, expectedCoef, OWLQNTest.EPSILON)
Assert.assertEquals(actualValue, expectedValue, OWLQNTest.EPSILON)
Expand Down