Stochastic Dynamic Programming in Java
We illustrate how to implement forward recursion in Java by relying on modelling constructs introduced in Java 8.
In order to implement a compact forward recursion algorithm we will rely on lambda expressions, functional interfaces, streams, and MapReduce.
A tutorial on Java lambda expressions and functional interfaces can be accessed here.
Tutorials on Java collections, streams, and MapReduce operations can be accessed here.
We discuss two Stochastic Dynamic Programming examples implemented in Java; the relevant code can be found in package jsdp.app.standalone.stochastic. Please note that these examples do not rely on the jsdp library; rather, they illustrate the core modelling constructs and abstractions on which jsdp is built.
The same technique can be employed to model and solve Deterministic Dynamic Programs; an application to the Knapsack problem is illustrated in jsdp.app.standalone.deterministic.
Consider a 3-period inventory control problem. At the beginning of each period the firm should decide how many units of a product should be produced. If production takes place for x units, where x > 0, we incur a production cost c(x). This cost comprises both a fixed and a variable component: c(x) = 0, if x = 0; c(x) = 3+2x, otherwise.
Production in each period cannot exceed 4 units. Demand in each period takes two possible values: 1 or 2 units with equal probability (0.5). Demand is observed in each period only after production has occurred. After meeting the current period's demand, a holding cost of $1 per unit is incurred for any item that is carried over from one period to the next. Because of limited capacity the inventory at the end of each period cannot exceed 3 units. All demand should be met on time (no backorders). If at the end of the planning horizon (i.e. period 3) the firm still has units in stock, these can be salvaged at $2 per unit. The initial inventory is 1 unit.
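The problem data above can be summarised as a functional equation. The following is a sketch in notation that the text itself does not introduce: f_t(i) is assumed to denote the minimum expected cost from period t onward when the opening inventory is i, x the production quantity, d the demand, and A(i) the set of feasible production quantities. The bound i + x >= 2 guarantees that the largest demand is met, and i + x <= 4 guarantees that closing inventory does not exceed 3 even under the smallest demand.

```latex
\begin{align*}
f_t(i) &= \min_{x \in A(i)} \sum_{d \in \{1,2\}} \tfrac{1}{2}\,
          \bigl[\, c(x) + (i + x - d) + f_{t+1}(i + x - d) \,\bigr],
          \qquad t = 1, 2, 3, \\
A(i)   &= \{\, x \in \{0, 1, \dots, 4\} : 2 \le i + x \le 4 \,\}, \\
f_4(i) &= -2\,i .
\end{align*}
```

The boundary condition f_4(i) = -2i encodes the salvage value of any stock left after period 3; the quantity of interest is f_1(1).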
We introduce a class to capture system states, as well as functional interfaces to capture actions, state transitions, and costs associated with a given state.
We will rely on the following libraries
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.DoubleStream;
In file InventoryControl.java we create the following class
public class InventoryControl {
    int planningHorizon;
    double[][] pmf;
    ...
}
with member variables planningHorizon, which denotes the number of periods in the planning horizon, and pmf a two dimensional array that records a given probability mass function describing random demand in each period.
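To make the layout of pmf concrete, here is a minimal sketch, assuming each row stores a pair {value, probability} as in the demand distribution used later in this example. The class name PmfSketch and its helper methods are hypothetical and are not part of the original code.

```java
import java.util.Arrays;

public class PmfSketch {
    // Expected value of a pmf stored as rows {value, probability}
    public static double expectedValue(double[][] pmf) {
        return Arrays.stream(pmf).mapToDouble(p -> p[0] * p[1]).sum();
    }

    // Sum of the probabilities; should be 1 for a valid pmf
    public static double totalProbability(double[][] pmf) {
        return Arrays.stream(pmf).mapToDouble(p -> p[1]).sum();
    }

    public static void main(String[] args) {
        // Demand is 1 or 2 units, each with probability 0.5
        double[][] pmf = {{1, 0.5}, {2, 0.5}};
        System.out.println("E[demand] = " + expectedValue(pmf));
        System.out.println("total probability = " + totalProbability(pmf));
    }
}
```

The same pattern, a stream over the rows of pmf mapped to probability-weighted terms and summed, is what the recursion relies on to compute expectations.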
We define the following constructor
public class InventoryControl {
    ...
    public InventoryControl(int planningHorizon,
                            double[][] pmf) {
        this.planningHorizon = planningHorizon;
        this.pmf = pmf;
    }
    ...
}
and we introduce a nested class to model the state of the system
public class InventoryControl {
    ...
    class State{
        int period;
        int initialInventory;

        public State(int period, int initialInventory){
            this.period = period;
            this.initialInventory = initialInventory;
        }

        public double[] getFeasibleActions(){
            return actionGenerator.apply(this);
        }

        @Override
        public int hashCode(){
            String hash = "";
            hash = (hash + period) + "_" + this.initialInventory;
            return hash.hashCode();
        }

        @Override
        public boolean equals(Object o){
            if(o instanceof State)
                return ((State) o).period == this.period &&
                       ((State) o).initialInventory == this.initialInventory;
            else
                return false;
        }

        @Override
        public String toString(){
            return this.period + " " + this.initialInventory;
        }
    }
    ...
}
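Methods hashCode() and equals() are needed because we will store states in hash tables: the hash code selects the bucket in which a state is stored, and equals() identifies the state within that bucket, so two State objects with the same period and inventory always refer to the same entry. Method getFeasibleActions() relies on actionGenerator, a function defined as follows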
public class InventoryControl {
    ...
    Function<State, double[]> actionGenerator;
    ...
}
One should recall that for each state, we must be able to generate all feasible actions. For the moment, we leave actionGenerator unimplemented. We will later define an appropriate lambda expression that returns the appropriate set of actions for each relevant state.
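To see why State overrides both hashCode() and equals(), consider the following minimal sketch. It uses a hypothetical stand-in class Key (not part of the original code) with the same two fields, and Objects.hash in place of the string-based hash used above; both choices are illustrative assumptions.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class StateKeySketch {
    // Hypothetical stand-in for the nested State class
    public static final class Key {
        final int period;
        final int inventory;

        public Key(int period, int inventory) {
            this.period = period;
            this.inventory = inventory;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Key)) return false;
            Key k = (Key) o;
            return k.period == period && k.inventory == inventory;
        }

        @Override
        public int hashCode() {
            // Equal keys must produce equal hash codes
            return Objects.hash(period, inventory);
        }
    }

    // Builds a one-entry cache keyed by (period, inventory)
    public static Map<Key, Double> cacheWith(int period, int inventory, double value) {
        Map<Key, Double> cache = new HashMap<>();
        cache.put(new Key(period, inventory), value);
        return cache;
    }

    public static void main(String[] args) {
        Map<Key, Double> cache = cacheWith(1, 1, 16.25);
        // A fresh but equal key finds the stored value
        System.out.println(cache.get(new Key(1, 1)));
        // A different key does not
        System.out.println(cache.containsKey(new Key(2, 1)));
    }
}
```

Without the two overrides, a freshly constructed key would be compared by object identity and the lookup would miss, which would defeat caching by state.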
In addition to the above functional interface we also define
public class InventoryControl {
    ...
    @FunctionalInterface
    interface StateTransitionFunction <S, A, R> {
        public S apply (S s, A a, R r);
    }

    public StateTransitionFunction<State, Double, Double> stateTransition;

    @FunctionalInterface
    interface ImmediateValueFunction <S, A, R, V> {
        public V apply (S s, A a, R r);
    }

    public ImmediateValueFunction<State, Double, Double, Double> immediateValueFunction;
    ...
}
capturing
- the state transition function, a function that, given a state, an action, and a random outcome, returns the associated future state; and
- the immediate value function, a function that, given a state, an action, and a random outcome, returns the associated immediate cost/profit.
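The two interfaces above share one shape: three inputs (state, action, random outcome) and one result. A minimal sketch of how such an interface can be implemented with a lambda is given below; the names TransitionSketch, TriFunction, NEXT_INVENTORY and COST are hypothetical, the state is reduced to a bare inventory level, and the cost figures are those of the inventory problem described earlier (with the salvage term omitted).

```java
public class TransitionSketch {
    // Hypothetical three-argument functional interface
    @FunctionalInterface
    public interface TriFunction<S, A, R, V> {
        V apply(S s, A a, R r);
    }

    // Closing inventory = opening inventory + order - demand
    public static final TriFunction<Integer, Double, Double, Integer> NEXT_INVENTORY =
        (inventory, order, demand) -> (int) (inventory + order - demand);

    // Production cost (3 + 2x if x > 0) plus holding cost of 1 per unit carried over
    public static final TriFunction<Integer, Double, Double, Double> COST =
        (inventory, order, demand) ->
            (order > 0 ? 3 + 2 * order : 0) + 1.0 * (inventory + order - demand);

    public static void main(String[] args) {
        // Opening inventory 1, order 3 units, demand 2 units
        System.out.println(NEXT_INVENTORY.apply(1, 3.0, 2.0));
        System.out.println(COST.apply(1, 3.0, 2.0));
    }
}
```

Because the interfaces carry a single abstract method, the lambdas can be assigned to fields after the object is constructed, which is how the main method later supplies the problem-specific logic.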
We have now defined all relevant constructs that are necessary to set up our forward recursion procedure
public class InventoryControl {
    ...
    Map<State, Double> cacheActions = new HashMap<>();
    Map<State, Double> cacheValueFunction = new HashMap<>();

    double f(State state){
        return cacheValueFunction.computeIfAbsent(state, s -> {
            double val = Arrays.stream(s.getFeasibleActions())
                .map(orderQty -> Arrays.stream(pmf)
                    .mapToDouble(p -> p[1]*immediateValueFunction.apply(s, orderQty, p[0])+
                        (s.period < this.planningHorizon ?
                            p[1]*f(stateTransition.apply(s, orderQty, p[0])) : 0))
                    .sum())
                .min()
                .getAsDouble();
            double bestOrderQty = Arrays.stream(s.getFeasibleActions())
                .filter(orderQty -> Arrays.stream(pmf)
                    .mapToDouble(p -> p[1]*immediateValueFunction.apply(s, orderQty, p[0])+
                        (s.period < this.planningHorizon ?
                            p[1]*f(stateTransition.apply(s, orderQty, p[0])) : 0))
                    .sum() == val)
                .findAny()
                .getAsDouble();
            cacheActions.putIfAbsent(s, bestOrderQty);
            return val;
        });
    }
    ...
}
The procedure directly implements the functional equation f[t,s]. It relies on memoization, via method computeIfAbsent(), to store the value of the functional equation for states that have already been visited, so that no state is processed more than once.
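One caveat is worth illustrating. The Java documentation for HashMap.computeIfAbsent() states that the mapping function should not modify the map during computation, and that the method will, on a best-effort basis, throw a ConcurrentModificationException when it detects this; a recursive mapping function inserts entries into the same map. A minimal sketch of an alternative memoization pattern, an explicit get() followed by put(), is shown below on a toy recursion. The class MemoSketch and its counter are hypothetical and only serve to show that each argument is evaluated once.

```java
import java.util.HashMap;
import java.util.Map;

public class MemoSketch {
    private final Map<Integer, Long> cache = new HashMap<>();
    private int evaluations = 0;

    // Toy recursion f(n) = f(n-1) + f(n-2), memoized with get/put
    public long f(int n) {
        Long cached = cache.get(n);
        if (cached != null) return cached;      // already visited
        evaluations++;                          // count first-time evaluations only
        long value = n < 2 ? n : f(n - 1) + f(n - 2);
        cache.put(n, value);                    // store after the recursive calls return
        return value;
    }

    public int evaluations() {
        return evaluations;
    }

    public static void main(String[] args) {
        MemoSketch memo = new MemoSketch();
        System.out.println(memo.f(30));
        System.out.println(memo.evaluations());
    }
}
```

Since the map is only modified outside any in-progress map operation, the pattern is safe for recursive value functions, at the price of slightly more verbose code.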
Finally, we define our main method.
public class InventoryControl {
    ...
    public static void main(String [] args){
        int planningHorizon = 3;           //Planning horizon length
        double fixedProductionCost = 3;    //Fixed production cost
        double perUnitProductionCost = 2;  //Per unit production cost
        int warehouseCapacity = 3;         //Warehouse capacity (max end-of-period inventory)
        double holdingCost = 1;            //Holding cost
        double salvageValue = 2;           //Salvage value
        int maxOrderQty = 4;               //Production capacity (max order quantity)

        /**
         * Probability mass function: Demand in each period takes two possible values: 1 or 2 units
         * with equal probability (0.5).
         */
        double pmf[][] = {{1,0.5},{2,0.5}};
        int maxDemand = (int) Arrays.stream(pmf).mapToDouble(v -> v[0]).max().getAsDouble();
        int minDemand = (int) Arrays.stream(pmf).mapToDouble(v -> v[0]).min().getAsDouble();
        ...
    }
    ...
}
and we introduce relevant lambda expressions that implement functional interfaces actionGenerator, stateTransition, and immediateValueFunction
public class InventoryControl {
    ...
    public static void main(String [] args){
        ...
        InventoryControl inventory = new InventoryControl(planningHorizon, pmf);

        /**
         * This function returns the set of actions associated with a given state:
         * order quantities from minQ up to the smaller of the production capacity
         * and the largest quantity that cannot overfill the warehouse
         */
        inventory.actionGenerator = state ->{
            int minQ = Math.max(maxDemand - state.initialInventory, 0);
            return DoubleStream.iterate(minQ, orderQty -> orderQty + 1)
                .limit(Math.min(maxOrderQty, warehouseCapacity +
                       minDemand - state.initialInventory) - minQ + 1)
                .toArray();
        };

        /**
         * State transition function; given a state, an action and a random outcome, the function
         * returns the future state
         */
        inventory.stateTransition = (state, action, randomOutcome) ->
            inventory.new State(state.period + 1, (int) (state.initialInventory + action - randomOutcome));

        /**
         * Immediate value function for a given state
         */
        inventory.immediateValueFunction = (state, action, demand) -> {
            double cost = (action > 0 ? fixedProductionCost + perUnitProductionCost*action : 0);
            cost += holdingCost*(state.initialInventory+action-demand);
            cost -= (state.period == planningHorizon ? salvageValue : 0)*(state.initialInventory+action-demand);
            return cost;
        };
        ...
    }
    ...
}
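The bounds used by the action generator can be checked in isolation. The following sketch, assuming integer order quantities and a hypothetical helper feasibleOrders (not part of the original code), computes the same range with IntStream.rangeClosed: the lower bound covers the largest demand, and the upper bound respects both the production capacity and the warehouse capacity under the smallest demand.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class FeasibleActionsSketch {
    // Order quantities x such that demand is always met and the warehouse never overflows
    public static int[] feasibleOrders(int inventory, int minDemand, int maxDemand,
                                       int maxOrderQty, int warehouseCapacity) {
        int lower = Math.max(maxDemand - inventory, 0);
        int upper = Math.min(maxOrderQty, warehouseCapacity + minDemand - inventory);
        // rangeClosed yields an empty stream when lower > upper
        return IntStream.rangeClosed(lower, upper).toArray();
    }

    public static void main(String[] args) {
        // Problem data: demand in {1, 2}, production capacity 4, warehouse capacity 3
        for (int inventory = 0; inventory <= 3; inventory++) {
            System.out.println(inventory + " -> "
                + Arrays.toString(feasibleOrders(inventory, 1, 2, 4, 3)));
        }
    }
}
```

With an opening inventory of 1, for instance, the firm must produce at least 1 unit to cover a demand of 2, and at most 3 units so that closing inventory stays within 3 when demand is 1.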
Finally, we set up the problem parameters and we call relevant methods to obtain an optimal solution
public class InventoryControl {
    ...
    public static void main(String [] args){
        ...
        /**
         * Initial problem conditions
         */
        int initialPeriod = 1;
        int initialInventory = 1;
        State initialState = inventory.new State(initialPeriod, initialInventory);

        /**
         * Run forward recursion and determine the expected total cost of an optimal policy
         */
        System.out.println("f_1("+initialInventory+")="+inventory.f(initialState));

        /**
         * Recover optimal action for period 1 when initial inventory at the beginning of period 1 is 1.
         */
        System.out.println("b_1("+initialInventory+")="+inventory.cacheActions.get(inventory.new State(initialPeriod, initialInventory)));
    }
    ...
}
The complete code is available here. After compiling and running the code the output obtained is
f_1(1)=16.25
b_1(1)=3.0
This problem is adapted from W. L. Winston, Operations Research: Applications and Algorithms (7th Edition), Duxbury Press, 2003, chap. 19, example 3.
A gambler has initialWealth. She is allowed to play a game of chance over a given betHorizon, and her goal is to maximize her probability of ending up with at least targetWealth.
If the gambler bets $b on a play of the game, then with probability p, she wins the game and increases her capital position by $b; with probability (1-p), she loses the game and decreases her capital by $b.
On any play of the game, the gambler may not bet more money than she has available.
Determine a betting strategy that will maximize the gambler's probability of attaining a wealth of at least $targetWealth by the end of the betting horizon.
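The problem statement above translates into the following functional equation. This is a sketch in notation that the text does not define: f_t(w) is assumed to be the maximum probability of finishing with at least the target when play t begins with wealth w, T stands for betHorizon, W for targetWealth, and bets are taken to be whole dollars.

```latex
\begin{align*}
f_t(w) &= \max_{b \in \{0, 1, \dots, w\}}
          \bigl[\, p\, f_{t+1}(w + b) + (1 - p)\, f_{t+1}(w - b) \,\bigr],
          \qquad t = 1, \dots, T, \\
f_{T+1}(w) &= \begin{cases} 1 & \text{if } w \ge W, \\ 0 & \text{otherwise.} \end{cases}
\end{align*}
```

The quantity of interest is f_1 evaluated at initialWealth.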
We will rely on the following libraries
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.DoubleStream;
In file GamblersRuin.java we create the following class
public class GamblersRuin {
    int betHorizon;
    double targetWealth;
    double[][] pmf;
    ...
}
with member variables betHorizon, which denotes the number of games we allow, targetWealth, which denotes our target wealth at the end of the betHorizon, and pmf a two dimensional array that records a given probability mass function describing the probability of winning/losing a game and the associated payoff/loss.
We define the following constructor
public class GamblersRuin {
    ...
    public GamblersRuin(double targetWealth,
                        int betHorizon,
                        double[][] pmf) {
        this.targetWealth = targetWealth;
        this.betHorizon = betHorizon;
        this.pmf = pmf;
    }
    ...
}
and we introduce a nested class to model the state of the system
public class GamblersRuin {
    ...
    class State{
        int period;
        double money;

        public State(int period, double money){
            this.period = period;
            this.money = money;
        }

        public double[] getFeasibleActions(){
            return actionGenerator.apply(this);
        }

        @Override
        public int hashCode(){
            String hash = "";
            hash = (hash + period) + "_" + money;
            return hash.hashCode();
        }

        @Override
        public boolean equals(Object o){
            if(o instanceof State)
                return ((State) o).period == this.period &&
                       ((State) o).money == this.money;
            else
                return false;
        }

        @Override
        public String toString(){
            return this.period + " " + this.money;
        }
    }
    ...
}
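Methods hashCode() and equals() are needed because we will store states in hash tables: the hash code selects the bucket in which a state is stored, and equals() identifies the state within that bucket, so two State objects with the same period and wealth always refer to the same entry. Method getFeasibleActions() relies on actionGenerator, a function defined as follows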
public class GamblersRuin {
    ...
    Function<State, double[]> actionGenerator;
    ...
}
One should recall that for each state, we must be able to generate all feasible actions. For the moment, we leave actionGenerator unimplemented. We will later define an appropriate lambda expression that returns the appropriate set of actions for each relevant state.
In addition to the above functional interface we also define
public class GamblersRuin {
    ...
    @FunctionalInterface
    interface StateTransitionFunction <S, A, R> {
        public S apply (S s, A a, R r);
    }

    public StateTransitionFunction<State, Double, Double> stateTransition;

    @FunctionalInterface
    interface ImmediateValueFunction <S, A, R, V> {
        public V apply (S s, A a, R r);
    }

    public ImmediateValueFunction<State, Double, Double, Double> immediateValueFunction;
    ...
}
capturing
- the state transition function, a function that, given a state, an action, and a random outcome, returns the associated future state; and
- the immediate value function, a function that, given a state, an action, and a random outcome, returns the associated immediate cost/profit.
We have now defined all relevant constructs that are necessary to set up our forward recursion procedure
public class GamblersRuin {
    ...
    Map<State, Double> cacheActions = new HashMap<>();
    Map<State, Double> cacheValueFunction = new HashMap<>();

    double f(State state){
        return cacheValueFunction.computeIfAbsent(state, s -> {
            double val = Arrays.stream(s.getFeasibleActions())
                .map(bet -> Arrays.stream(pmf)
                    .mapToDouble(p -> p[1]*immediateValueFunction.apply(s, bet, p[0])+
                        (s.period < this.betHorizon ?
                            p[1]*f(stateTransition.apply(s, bet, p[0])) : 0))
                    .sum())
                .max()
                .getAsDouble();
            double bestBet = Arrays.stream(s.getFeasibleActions())
                .filter(bet -> Arrays.stream(pmf)
                    .mapToDouble(p -> p[1]*immediateValueFunction.apply(s, bet, p[0])+
                        (s.period < this.betHorizon ?
                            p[1]*f(stateTransition.apply(s, bet, p[0])) : 0))
                    .sum() == val)
                .findAny()
                .getAsDouble();
            cacheActions.putIfAbsent(s, bestBet);
            return val;
        });
    }
    ...
}
The procedure directly implements the functional equation f[t,s]. It relies on memoization, via method computeIfAbsent(), to store the value of the functional equation for states that have already been visited, so that no state is processed more than once.
Finally, we define our main method.
public class GamblersRuin {
    ...
    public static void main(String [] args){
        int bettingHorizon = 4;   //Betting horizon length
        double targetWealth = 6;  //Target wealth

        /**
         * Probability mass function: with probability 0.6 we lose the bet amount (multiplier is 0)
         * with probability 0.4 we double the bet amount (multiplier is 2)
         */
        double pmf[][] = {{0,0.6},{2,0.4}};
        ...
    }
    ...
}
and we introduce relevant lambda expressions that implement functional interfaces actionGenerator, stateTransition, and immediateValueFunction
public class GamblersRuin {
    ...
    public static void main(String [] args){
        ...
        GamblersRuin ruin = new GamblersRuin(targetWealth, bettingHorizon, pmf);

        /**
         * This function returns the set of actions associated with a given state
         */
        ruin.actionGenerator = state ->{
            return DoubleStream.iterate(0, bet -> bet + 1)
                .limit((int) Math.ceil(Math.min(targetWealth/2, state.money + 1)))
                .toArray();
        };

        /**
         * State transition function; given a state, an action and a random outcome, the function
         * returns the future state
         */
        ruin.stateTransition = (state, action, randomOutcome) ->
            ruin.new State(state.period + 1, state.money - action + action*randomOutcome);

        /**
         * Immediate value function for a given state
         */
        ruin.immediateValueFunction = (state, action, randomOutcome) -> {
            if(state.period == ruin.betHorizon)
                return state.money - action + action*randomOutcome >= ruin.targetWealth ? 1.0 : 0.0;
            else
                return 0.0;
        };
        ...
    }
    ...
}
Finally, we set up the problem parameters and we call relevant methods to obtain an optimal solution
public class GamblersRuin {
    ...
    public static void main(String [] args){
        ...
        /**
         * Initial problem conditions
         */
        int initialPeriod = 1;
        double initialWealth = 2;
        State initialState = ruin.new State(initialPeriod, initialWealth);

        /**
         * Run forward recursion and determine the probability of achieving the target wealth when
         * one follows an optimal policy
         */
        System.out.println("f_1(2)="+ruin.f(initialState));

        /**
         * Recover optimal action for period 2 when initial wealth at the beginning of period 2 is $1.
         */
        System.out.println("b_2(1)="+ruin.cacheActions.get(ruin.new State(2, 1)));
    }
    ...
}
The complete code is available here. After compiling and running the code the output obtained is
f_1(2)=0.19840000000000002
b_2(1)=1.0