
Stochastic Dynamic Programming in Java

Roberto Rossi edited this page Feb 6, 2017 · 48 revisions

We illustrate how to implement forward recursion in Java by relying on modelling constructs introduced in Java 8.

Key Java 8 features

In order to implement a compact forward recursion algorithm, we will rely on lambda expressions, functional interfaces, streams, and MapReduce.

A tutorial on Java lambda expressions and functional interfaces can be accessed here.

Tutorials on Java collections, streams, and MapReduce operations can be accessed here.

Applications

We discuss two Stochastic Dynamic Programming examples implemented in Java; the relevant code can be found in package jsdp.app.standalone.stochastic. Please note that these examples do not rely on the jsdp library; rather, they illustrate the core modelling constructs and abstractions on which jsdp is built.

The same technique can be employed to model and solve Deterministic Dynamic Programs; an application to the Knapsack problem is illustrated in jsdp.app.standalone.deterministic.

Stochastic inventory control

Consider a 3-period inventory control problem. At the beginning of each period, the firm must decide how many units of a product to produce. Producing x units incurs a production cost c(x), which comprises both a fixed and a variable component: c(x) = 0 if x = 0, and c(x) = 3 + 2x otherwise.

Production in each period cannot exceed 4 units. Demand in each period takes one of two values, 1 or 2 units, with equal probability (0.5). In each period, demand is observed only after production has occurred. After the current period's demand has been met, a holding cost of $1 per unit is incurred for any item carried over from one period to the next. Because of limited capacity, the inventory at the end of each period cannot exceed 3 units. All demand must be met on time (no backorders). If the firm still has units in stock at the end of the planning horizon (i.e. period 3), these can be salvaged at $2 per unit. The initial inventory is 1 unit.
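As a quick check of the cost structure, the sketch below evaluates c(x) together with the expected production-plus-holding cost of a single period, given an initial inventory and an order quantity. The class and method names are hypothetical. The $2 salvage credit applies only in period 3, so it is left out here.

```java
// Hypothetical helper illustrating the cost structure of the inventory problem
public class OnePeriodCost {
    static final double FIXED_COST = 3;    // fixed production cost
    static final double UNIT_COST = 2;     // per unit production cost
    static final double HOLDING_COST = 1;  // per unit carried to the next period

    // c(x) = 0 if x = 0; c(x) = 3 + 2x otherwise
    static double productionCost(int x) {
        return x > 0 ? FIXED_COST + UNIT_COST * x : 0;
    }

    // Expected production plus holding cost; pmf rows are {demand, probability}.
    // Assumes inventory + orderQty covers every demand value (no backorders).
    static double expectedPeriodCost(int inventory, int orderQty, double[][] pmf) {
        double expectedHolding = 0;
        for (double[] p : pmf) {
            expectedHolding += p[1] * HOLDING_COST * (inventory + orderQty - p[0]);
        }
        return productionCost(orderQty) + expectedHolding;
    }

    public static void main(String[] args) {
        double[][] pmf = {{1, 0.5}, {2, 0.5}};
        System.out.println(productionCost(3));
        System.out.println(expectedPeriodCost(1, 3, pmf));
    }
}
```

Take an initial inventory of 1 and an order of 3 units. Then c(3) = 9, and the expected holding cost is 0.5 * 3 + 0.5 * 2 = 2.5, for a total of 11.5.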

We introduce a class to capture system states, as well as functional interfaces to capture actions, state transitions, and costs associated with a given state.

We will rely on the following imports:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.DoubleStream;

In file InventoryControl.java we create the following class

public class InventoryControl {
  int planningHorizon;
  double[][] pmf;
  ...
}

with member variables planningHorizon, which denotes the number of periods in the planning horizon, and pmf, a two-dimensional array that records the probability mass function of the random demand. The same distribution applies in every period, and each row holds a {demand value, probability} pair.
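The stand-alone sketch below illustrates the row layout of pmf. The class and method names are hypothetical. It checks that the probabilities sum to one and computes the expected demand with the same stream operations used throughout this page.

```java
import java.util.Arrays;

// Hypothetical sanity check on the {value, probability} row layout of pmf
public class PmfCheck {
    // Column 1 holds probabilities, which should sum to one
    static double totalProbability(double[][] pmf) {
        return Arrays.stream(pmf).mapToDouble(p -> p[1]).sum();
    }

    // Expected value: sum of value (column 0) times probability (column 1)
    static double expectedValue(double[][] pmf) {
        return Arrays.stream(pmf).mapToDouble(p -> p[0] * p[1]).sum();
    }

    public static void main(String[] args) {
        double[][] pmf = {{1, 0.5}, {2, 0.5}};   // demand 1 or 2, each with probability 0.5
        System.out.println(totalProbability(pmf) + " " + expectedValue(pmf));   // 1.0 1.5
    }
}
```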

We define the following constructor

public class InventoryControl {
  ...
  public InventoryControl(int planningHorizon,
                          double[][] pmf) {
     this.planningHorizon = planningHorizon;
     this.pmf = pmf;
  }
  ...
}

and we introduce a nested class to model the state of the system

public class InventoryControl {
  ...
  class State{
     int period;
     int initialInventory;

     public State(int period, int initialInventory){
        this.period = period;
        this.initialInventory = initialInventory;
     }

     public double[] getFeasibleActions(){
        return actionGenerator.apply(this);
     }
  
     @Override
     public int hashCode(){
        String hash = "";
        hash = (hash + period) + "_" + this.initialInventory;
        return hash.hashCode();
     }

     @Override
     public boolean equals(Object o){
        if(o instanceof State)
           return  ((State) o).period == this.period &&
                   ((State) o).initialInventory == this.initialInventory;
        else
           return false;
     }

     @Override
     public String toString(){
        return this.period + " " + this.initialInventory;
     }
  }
  ...
}

Methods hashCode() and equals() are overridden because states will be stored as keys in hash maps. Equal states must return the same hash code. Hash codes need not be unique, however, since collisions between different states are resolved by equals(). Method getFeasibleActions() relies on actionGenerator, a function defined as follows

public class InventoryControl {
  ...
  Function<State, double[]> actionGenerator;
  ...
}
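To see the hashCode()/equals() contract at work, consider the stand-alone sketch below. Its class names are hypothetical, and it builds the hash code with java.util.Objects.hash, a common alternative to the string concatenation used above. A freshly created State that is equal to a stored key retrieves the cached value.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-alone state class: a sketch of the hashCode()/equals() contract
public class StateKeyDemo {
    static final class State {
        final int period;
        final int initialInventory;

        State(int period, int initialInventory) {
            this.period = period;
            this.initialInventory = initialInventory;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof State)) return false;
            State other = (State) o;
            return other.period == period && other.initialInventory == initialInventory;
        }

        // Equal states must return equal hash codes; distinct states may collide
        @Override
        public int hashCode() {
            return Objects.hash(period, initialInventory);
        }
    }

    public static void main(String[] args) {
        Map<State, Double> cache = new HashMap<>();
        cache.put(new State(1, 1), 16.25);
        // A new but equal State object finds the cached entry
        System.out.println(cache.get(new State(1, 1)));   // 16.25
        System.out.println(cache.get(new State(2, 1)));   // null
    }
}
```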

Recall that, for each state, we must be able to generate all feasible actions. For the moment, we leave actionGenerator unimplemented; we will later define a lambda expression that returns the set of feasible actions for each state.
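In the inventory example, the feasible order quantities for an initial inventory I are bounded on both sides:

- The lower bound is max(maxDemand - I, 0), which covers the largest demand so that no backorders occur.
- The upper bound is min(maxOrderQty, warehouseCapacity + minDemand - I), which respects the production limit and keeps end-of-period inventory within capacity even under the smallest demand.

The stand-alone sketch below implements this rule with integer actions. Its names are hypothetical.

```java
import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.IntStream;

// Hypothetical sketch of the feasible-action rule for the inventory problem
public class FeasibleActions {
    static final int MAX_ORDER_QTY = 4;       // production limit per period
    static final int WAREHOUSE_CAPACITY = 3;  // max end-of-period inventory
    static final int MIN_DEMAND = 1, MAX_DEMAND = 2;

    // Maps an initial inventory level to the feasible order quantities
    static final Function<Integer, int[]> ACTIONS = inventory -> {
        int minQ = Math.max(MAX_DEMAND - inventory, 0);               // meet the largest demand
        int maxQ = Math.min(MAX_ORDER_QTY,
                            WAREHOUSE_CAPACITY + MIN_DEMAND - inventory); // respect capacity
        return IntStream.rangeClosed(minQ, maxQ).toArray();           // empty if minQ > maxQ
    };

    public static void main(String[] args) {
        for (int inventory = 0; inventory <= WAREHOUSE_CAPACITY; inventory++) {
            System.out.println(inventory + " -> " + Arrays.toString(ACTIONS.apply(inventory)));
        }
    }
}
```

For inventory levels 0, 1, 2 and 3, it yields {2, 3, 4}, {1, 2, 3}, {0, 1, 2} and {0, 1}, respectively.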

In addition to the above functional interface, we also define

public class InventoryControl {
  ...
  @FunctionalInterface
  interface StateTransitionFunction <S, A, R> { 
    public S apply (S s, A a, R r);
  }

  public StateTransitionFunction<State, Double, Double> stateTransition;

  @FunctionalInterface
  interface ImmediateValueFunction <S, A, R, V> { 
     public V apply (S s, A a, R r);
  }

  public ImmediateValueFunction<State, Double, Double, Double> immediateValueFunction;
  ...
}

capturing

  • the state transition function, a function that, given a state, an action, and a random outcome, returns the associated future state; and
  • the immediate value function, a function that, given a state, an action, and a random outcome, returns the associated immediate cost/profit.
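Java 8's java.util.function package provides Function and BiFunction but no three-argument functional interface, which is why custom interfaces such as StateTransitionFunction are declared. The minimal sketch below shows how one is implemented with a lambda and then invoked. The class name is hypothetical, and the state is simplified to an integer inventory level.

```java
// Hypothetical demo of a custom three-argument functional interface
public class TriFunctionDemo {
    @FunctionalInterface
    interface StateTransitionFunction<S, A, R> {
        S apply(S s, A a, R r);
    }

    // State simplified to an inventory level: next = inventory + order - demand
    static final StateTransitionFunction<Integer, Integer, Integer> NEXT =
        (inventory, orderQty, demand) -> inventory + orderQty - demand;

    public static void main(String[] args) {
        System.out.println(NEXT.apply(1, 3, 2));   // 2
    }
}
```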

We have now defined all the constructs necessary to set up our forward recursion procedure:

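public class InventoryControl {
  ...
  Map<State, Double> cacheActions = new HashMap<>();
  Map<State, Double> cacheValueFunction = new HashMap<>();

  // Expected cost of ordering orderQty in state s: immediate cost plus, before the
  // end of the planning horizon, the optimal expected cost of the resulting state
  double expectedCost(State s, double orderQty){
     return Arrays.stream(pmf)
                  .mapToDouble(p -> p[1]*immediateValueFunction.apply(s, orderQty, p[0])+
                                    (s.period < this.planningHorizon ?
                                    p[1]*f(stateTransition.apply(s, orderQty, p[0])) : 0))
                  .sum();
  }

  double f(State state){
     // Memoization: return the stored value if this state was already visited.
     // HashMap.computeIfAbsent() is not used here because the recursive call to f()
     // would modify the map while it is being updated (Java 9+ throws
     // ConcurrentModificationException in that case).
     Double cached = cacheValueFunction.get(state);
     if(cached != null)
        return cached;
     double val = Arrays.stream(state.getFeasibleActions())
                        .map(orderQty -> expectedCost(state, orderQty))
                        .min()
                        .getAsDouble();
     // Recover an action attaining the minimum; successor values are now cached,
     // so the expected costs are recomputed identically
     double bestOrderQty = Arrays.stream(state.getFeasibleActions())
                                 .filter(orderQty -> expectedCost(state, orderQty) == val)
                                 .findFirst()
                                 .getAsDouble();
     cacheActions.putIfAbsent(state, bestOrderQty);
     cacheValueFunction.put(state, val);
     return val;
  }
  ...

}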

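The procedure directly implements the functional equation f_t(s). For each state, it selects the action that minimizes the expected immediate cost plus the expected optimal cost of the next state, and it records that action in cacheActions. It also relies on memoization: the value of the functional equation is stored in cacheValueFunction for every state already visited, which ensures that no state is processed more than once.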
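For reference, the functional equation can be written out as shown below. The notation is assumed here rather than taken from the original:

- I is the initial inventory.
- Q is an order quantity in the feasible set A(I) returned by actionGenerator.
- d is the demand, which occurs with probability p(d).
- h = 1 is the holding cost and v = 2 the salvage value.
- T = 3 is the planning horizon.

The call f(initialState) then evaluates f_1(1).

```latex
\begin{aligned}
f_t(I) &= \min_{Q \in A(I)} \sum_{d} p(d)\Big[c(Q) + h\,(I + Q - d) - \mathbb{1}_{\{t = T\}}\, v\,(I + Q - d) + f_{t+1}(I + Q - d)\Big], \quad t = 1,\dots,T,\\
f_{T+1}(I) &= 0.
\end{aligned}
```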

Next, we define our main method.

public class InventoryControl {
  ...
  public static void main(String [] args){
     int planningHorizon = 3;         //Planning horizon length
     double fixedProductionCost = 3;  //Fixed production cost
     double perUnitProductionCost = 2;//Per unit production cost
     int warehouseCapacity = 3;       //Warehouse capacity (max end-of-period inventory)
     double holdingCost = 1;          //Holding cost
     double salvageValue = 2;         //Salvage value
     int maxOrderQty = 4;             //Max order quantity
     /**
      * Probability mass function: Demand in each period takes two possible values: 1 or 2 units 
      * with equal probability (0.5).
      */
     double pmf[][] = {{1,0.5},{2,0.5}}; 
     int maxDemand =  (int) Arrays.stream(pmf).mapToDouble(v -> v[0]).max().getAsDouble();
     int minDemand =  (int) Arrays.stream(pmf).mapToDouble(v -> v[0]).min().getAsDouble();
     ...
  }
  ...
}

and we introduce lambda expressions that implement the functions actionGenerator, stateTransition, and immediateValueFunction

public class InventoryControl {
  ...
  public static void main(String [] args){
    ...
    InventoryControl inventory = new InventoryControl(planningHorizon, pmf);
  
    /**
     * This function returns the set of actions associated with a given state
     */
    inventory.actionGenerator = state ->{
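       //Smallest order that meets the largest demand (no backorders)
       int minQ = Math.max(maxDemand - state.initialInventory, 0);
       //Largest order within the production limit that, under the smallest demand,
       //keeps end-of-period inventory within the warehouse capacity
       int maxQ = Math.min(maxOrderQty, warehouseCapacity + minDemand - state.initialInventory);
       return DoubleStream.iterate(minQ, orderQty -> orderQty + 1)
                          .limit(Math.max(maxQ - minQ + 1, 0))
                          .toArray();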
    };
  
    /**
     * State transition function; given a state, an action and a random outcome, the function
     * returns the future state
     */
    inventory.stateTransition = (state, action, randomOutcome) -> 
       inventory.new State(state.period + 1, (int) (state.initialInventory + action - randomOutcome));
  
    /**
     * Immediate value function for a given state
     */
    inventory.immediateValueFunction = (state, action, demand) -> {
        double cost = (action > 0 ? fixedProductionCost + perUnitProductionCost*action : 0);
        cost += holdingCost*(state.initialInventory+action-demand);
        cost -= (state.period == planningHorizon ? salvageValue : 0)*(state.initialInventory+action-demand);
        return cost;
    };
    ...      
  }
  ...
}

Finally, we set the initial conditions and call the relevant methods to obtain an optimal solution

public class InventoryControl {
  ...
  public static void main(String [] args){
     ...
    /**
     * Initial problem conditions
     */
     int initialPeriod = 1;
     int initialInventory = 1;
     State initialState = inventory.new State(initialPeriod, initialInventory);
  
    /**
     * Run forward recursion and determine the expected total cost of an optimal policy
     */
     System.out.println("f_1("+initialInventory+")="+inventory.f(initialState));

    /**
     * Recover optimal action for period 1 when initial inventory at the beginning of period 1 is 1.
     */
     System.out.println("b_1("+initialInventory+")="+inventory.cacheActions.get(inventory.new State(initialPeriod, initialInventory)));
  }  
  ...
}

The complete code is available here. After compiling and running the code, the output obtained is

f_1(1)=16.25
b_1(1)=3.0
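As a cross-check, the same instance can be solved bottom-up. The sketch below swaps the memoized recursion for classic tabular backward induction, sweeping periods T, T-1, ..., 1 and inventory levels 0 to 3. The class and field names are hypothetical, while the costs, bounds and {demand, probability} pmf are those given in the problem statement. When run, it prints f_1(1)=16.25 b_1(1)=3.

```java
// Hypothetical cross-check: the inventory problem solved by tabular backward
// induction (periods T, T-1, ..., 1) instead of memoized recursion
public class InventoryBackwardInduction {
    static final int T = 3, MAX_ORDER = 4, CAPACITY = 3;
    static final double FIXED = 3, UNIT = 2, HOLDING = 1, SALVAGE = 2;
    static final double[][] PMF = {{1, 0.5}, {2, 0.5}};   // {demand, probability}
    static final int MIN_DEMAND = 1, MAX_DEMAND = 2;

    // value[t][i]: optimal expected cost from period t with initial inventory i;
    // row T + 1 stays at 0, the terminal condition
    static final double[][] value = new double[T + 2][CAPACITY + 1];
    static final int[][] action = new int[T + 1][CAPACITY + 1];

    static void solve() {
        for (int t = T; t >= 1; t--) {
            for (int i = 0; i <= CAPACITY; i++) {
                double best = Double.POSITIVE_INFINITY;
                int minQ = Math.max(MAX_DEMAND - i, 0);                   // no backorders
                int maxQ = Math.min(MAX_ORDER, CAPACITY + MIN_DEMAND - i); // capacity
                for (int q = minQ; q <= maxQ; q++) {
                    double expected = 0;
                    for (double[] p : PMF) {
                        int end = (int) (i + q - p[0]);   // end-of-period inventory
                        double cost = (q > 0 ? FIXED + UNIT * q : 0) + HOLDING * end
                                    - (t == T ? SALVAGE * end : 0);
                        expected += p[1] * (cost + value[t + 1][end]);
                    }
                    if (expected < best) { best = expected; action[t][i] = q; }
                }
                value[t][i] = best;
            }
        }
    }

    public static void main(String[] args) {
        solve();
        System.out.println("f_1(1)=" + value[1][1] + " b_1(1)=" + action[1][1]);
    }
}
```

Bottom-up induction evaluates every inventory level in every period, including unreachable ones, whereas the memoized recursion only visits states reachable from the initial state.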

Gambler's ruin

This problem is adapted from W. L. Winston, Operations Research: Applications and Algorithms (7th Edition), Duxbury Press, 2003, chap. 19, example 3.

A gambler has an initial wealth of initialWealth dollars. She is allowed to play a game of chance over a given betHorizon, and her goal is to maximize her probability of ending up with at least targetWealth.

If the gambler bets $b on a play of the game, then with probability p, she wins the game and increases her capital position by $b; with probability (1-p), she loses the game and decreases her capital by $b.

On any play of the game, the gambler may not bet more money than she has available.

Determine a betting strategy that will maximize the gambler's probability of attaining a wealth of at least $targetWealth by the end of the betting horizon.
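The problem can be written as a functional equation. The notation below is assumed here rather than taken from the original:

- m is the wealth at the start of play t.
- b is the bet, chosen from a set B(m) of admissible bets that never exceeds m.
- T = betHorizon.
- The game is won with probability 0.4 and lost with probability 0.6, as in the main method below.

Under this notation, the optimal probability of reaching the target satisfies:

```latex
\begin{aligned}
f_t(m) &= \max_{b \in B(m)} \Big[\,0.6\, f_{t+1}(m - b) + 0.4\, f_{t+1}(m + b)\,\Big], \quad t = 1,\dots,T,\\
f_{T+1}(m) &= \begin{cases} 1 & \text{if } m \ge \text{targetWealth},\\ 0 & \text{otherwise.} \end{cases}
\end{aligned}
```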

We will rely on the following imports:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.DoubleStream;

In file GamblersRuin.java we create the following class

public class GamblersRuin {
  int betHorizon;
  double targetWealth;
  double[][] pmf;
  ...
}

with the following member variables:

- betHorizon denotes the number of games the gambler is allowed to play.
- targetWealth denotes the target wealth at the end of the betHorizon.
- pmf is a two-dimensional array that records the probability mass function of the outcome of a game. Each row holds a {multiplier, probability} pair, where the multiplier is applied to the amount bet (in the example below, 0 if the game is lost and 2 if it is won).
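The stand-alone sketch below shows how a {multiplier, probability} row turns a bet into next-period wealth, using the same expression money - bet + bet * multiplier as the state transition defined later. The class and method names are hypothetical. It also computes the expected wealth after one play, which is money - 0.2 * bet, so the game is unfavourable for any positive bet.

```java
// Hypothetical sketch: how a {multiplier, probability} row turns a bet into next wealth
public class BetOutcomes {
    static final double[][] PMF = {{0, 0.6}, {2, 0.4}};   // lose w.p. 0.6, double the bet w.p. 0.4

    // Wealth after betting `bet` out of `money` when the game returns bet * multiplier
    static double nextWealth(double money, double bet, double multiplier) {
        return money - bet + bet * multiplier;
    }

    // Expected wealth after one play: money - bet + 0.8 * bet
    static double expectedNextWealth(double money, double bet) {
        double expected = 0;
        for (double[] p : PMF) expected += p[1] * nextWealth(money, bet, p[0]);
        return expected;
    }

    public static void main(String[] args) {
        for (double[] p : PMF)
            System.out.println("probability " + p[1] + ": wealth 2 -> " + nextWealth(2, 2, p[0]));
        System.out.println("expected: " + expectedNextWealth(2, 2));
    }
}
```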

We define the following constructor

public class GamblersRuin {
  ...
  public GamblersRuin(double targetWealth,
                      int betHorizon,
                      double[][] pmf) {
     this.targetWealth = targetWealth;
     this.betHorizon = betHorizon;
     this.pmf = pmf;
  }
  ...
}

and we introduce a nested class to model the state of the system

public class GamblersRuin {
  ...
  class State{
    int period;
    double money;

    public State(int period, double money){
       this.period = period;
       this.money = money;
    }

    public double[] getFeasibleActions(){
       return actionGenerator.apply(this);
    }
  
    @Override
    public int hashCode(){
       String hash = "";
       hash = (hash + period) + "_" + money;
       return hash.hashCode();
    }

    @Override
    public boolean equals(Object o){
       if(o instanceof State)
          return  ((State) o).period == this.period &&
                  ((State) o).money == this.money;
       else
          return false;
    }

    @Override
    public String toString(){
       return this.period + " " + this.money;
    }
    ...
  }
  ...
}

Methods hashCode() and equals() are overridden because states will be stored as keys in hash maps. Equal states must return the same hash code. Hash codes need not be unique, however, since collisions between different states are resolved by equals(). Method getFeasibleActions() relies on actionGenerator, a function defined as follows

public class GamblersRuin {
  ...
  Function<State, double[]> actionGenerator;
  ...
}

Recall that, for each state, we must be able to generate all feasible actions. For the moment, we leave actionGenerator unimplemented; we will later define a lambda expression that returns the set of feasible actions for each state.

In addition to the above functional interface, we also define

public class GamblersRuin {
  ...
  @FunctionalInterface
  interface StateTransitionFunction <S, A, R> { 
    public S apply (S s, A a, R r);
  }

  public StateTransitionFunction<State, Double, Double> stateTransition;

  @FunctionalInterface
  interface ImmediateValueFunction <S, A, R, V> { 
     public V apply (S s, A a, R r);
  }

  public ImmediateValueFunction<State, Double, Double, Double> immediateValueFunction;
  ...

}

capturing

  • the state transition function, a function that, given a state, an action, and a random outcome, returns the associated future state; and
  • the immediate value function, a function that, given a state, an action, and a random outcome, returns the associated immediate cost/profit.

We have now defined all the constructs necessary to set up our forward recursion procedure:

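public class GamblersRuin {
  ...
  Map<State, Double> cacheActions = new HashMap<>();
  Map<State, Double> cacheValueFunction = new HashMap<>();

  // Expected value of placing bet in state s: immediate value plus, before the end
  // of the betting horizon, the optimal value of the resulting state
  double expectedValue(State s, double bet){
     return Arrays.stream(pmf)
                  .mapToDouble(p -> p[1]*immediateValueFunction.apply(s, bet, p[0])+
                                    (s.period < this.betHorizon ?
                                    p[1]*f(stateTransition.apply(s, bet, p[0])) : 0))
                  .sum();
  }

  double f(State state){
     // Memoization: return the stored value if this state was already visited.
     // HashMap.computeIfAbsent() is not used here because the recursive call to f()
     // would modify the map while it is being updated (Java 9+ throws
     // ConcurrentModificationException in that case).
     Double cached = cacheValueFunction.get(state);
     if(cached != null)
        return cached;
     double val = Arrays.stream(state.getFeasibleActions())
                        .map(bet -> expectedValue(state, bet))
                        .max()
                        .getAsDouble();
     // Recover a bet attaining the maximum; successor values are now cached,
     // so the expected values are recomputed identically
     double bestBet = Arrays.stream(state.getFeasibleActions())
                            .filter(bet -> expectedValue(state, bet) == val)
                            .findFirst()
                            .getAsDouble();
     cacheActions.putIfAbsent(state, bestBet);
     cacheValueFunction.put(state, val);
     return val;
  }
  ...

}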

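The procedure directly implements the functional equation f_t(s). For each state, it selects the bet that maximizes the expected immediate value plus the expected optimal value of the next state, and it records that bet in cacheActions. It also relies on memoization: the value of the functional equation is stored in cacheValueFunction for every state already visited, which ensures that no state is processed more than once.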
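A recursive mapping function passed to HashMap.computeIfAbsent() must not modify the same map; since Java 9, HashMap detects this and throws ConcurrentModificationException. The self-contained sketch below sidesteps the problem by memoizing the gambler's recursion with an explicit get/put, storing each value only after all recursive calls have returned. Its class name and String keys are hypothetical simplifications. The terminal reward is applied one step past the horizon, which is equivalent to the immediate value function used in this example. The bet range mirrors the actionGenerator lambda of this example, and the program prints a value of approximately 0.1984.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical compact variant of the gambler's recursion with explicit get/put memoization
public class GamblerMemo {
    static final int HORIZON = 4;
    static final double TARGET = 6;
    static final double[][] PMF = {{0, 0.6}, {2, 0.4}};   // {multiplier, probability}
    static final Map<String, Double> cache = new HashMap<>();

    // Maximum probability of ending with at least TARGET, from `money` at play `period`
    static double f(int period, double money) {
        if (period > HORIZON) return money >= TARGET ? 1.0 : 0.0;   // terminal reward
        String key = period + "_" + money;
        Double cached = cache.get(key);
        if (cached != null) return cached;
        double best = 0;
        // Bets 0 .. ceil(min(TARGET / 2, money + 1)) - 1, as in the actionGenerator lambda
        int numBets = (int) Math.ceil(Math.min(TARGET / 2, money + 1));
        for (int bet = 0; bet < numBets; bet++) {
            double expected = 0;
            for (double[] p : PMF)
                expected += p[1] * f(period + 1, money - bet + bet * p[0]);
            best = Math.max(best, expected);
        }
        cache.put(key, best);   // store only after all recursive calls have returned
        return best;
    }

    public static void main(String[] args) {
        System.out.println("f_1(2)=" + f(1, 2));
    }
}
```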

Next, we define our main method.

public class GamblersRuin {
  ...
  public static void main(String [] args){
    int bettingHorizon = 4;    //Planning horizon length
    double targetWealth = 6;   //Target wealth

    /**
     * Probability mass function: with probability 0.6 we lose the bet amount (multiplier is 0)
     * with probability 0.4 we double the bet amount (multiplier is 2)
     */
    double pmf[][] = {{0,0.6},{2,0.4}}; 
    ...
  }
  ...
}

and we introduce lambda expressions that implement the functions actionGenerator, stateTransition, and immediateValueFunction

public class GamblersRuin {
  ...
  public static void main(String [] args){
    ...
    GamblersRuin ruin = new GamblersRuin(targetWealth, bettingHorizon, pmf);
  
    /**
     * This function returns the set of actions associated with a given state
     */
    ruin.actionGenerator = state ->{
       return DoubleStream.iterate(0, bet -> bet + 1)
                          .limit((int) Math.ceil(Math.min(targetWealth/2, state.money + 1)))
                          .toArray();
    };
  
    /**
     * State transition function; given a state, an action and a random outcome, the function
     * returns the future state
     */
    ruin.stateTransition = (state, action, randomOutcome) -> 
       ruin.new State(state.period + 1, state.money - action + action*randomOutcome);
  
    /**
     * Immediate value function for a given state
     */
    ruin.immediateValueFunction = (state, action, randomOutcome) -> {
          if(state.period == ruin.betHorizon)
             return state.money - action + action*randomOutcome >= ruin.targetWealth ? 1.0 : 0.0;    
          else
             return 0.0;   
       };
    ...      
  }
  ...
}

Finally, we set the initial conditions and call the relevant methods to obtain an optimal solution

public class GamblersRuin {
  ...
  public static void main(String [] args){
     ...
    /**
     * Initial problem conditions
     */
     int initialPeriod = 1;
     double initialWealth = 2;
     State initialState = ruin.new State(initialPeriod, initialWealth);
  
    /**
     * Run forward recursion and determine the probability of achieving the target wealth when
     * one follows an optimal policy
     */
     System.out.println("f_1(2)="+ruin.f(initialState));

    /**
     * Recover optimal action for period 2 when initial wealth at the beginning of period 2 is $1.
     */
     System.out.println("b_2(1)="+ruin.cacheActions.get(ruin.new State(2, 1)));
  }  
  ...
}

The complete code is available here. After compiling and running the code, the output obtained is

f_1(2)=0.19840000000000002
b_2(1)=1.0
