diff --git a/docs/api.rst b/docs/api.rst
index 662c9c8..54d2122 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -229,6 +229,7 @@ Policy Optimization
 .. autosummary::
 
     clipped_surrogate_pg_loss
+    cmpo_policy_targets
     constant_policy_targets
    dpg_loss
    entropy_loss
@@ -238,6 +239,7 @@ Policy Optimization
    qpg_loss
    rm_loss
    rpg_loss
+    sampled_cmpo_policy_targets
    sampled_policy_distillation_loss
    zero_policy_targets
 
@@ -247,6 +249,18 @@ Clipped Surrogate PG Loss
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: clipped_surrogate_pg_loss
 
+CMPO Policy Targets
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: cmpo_policy_targets
+
+
+Sampled CMPO Policy Targets
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: sampled_cmpo_policy_targets
+
+
 Compute Parametric KL Penalty and Dual Loss
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/rlax/__init__.py b/rlax/__init__.py
index 5a69b60..57af086 100644
--- a/rlax/__init__.py
+++ b/rlax/__init__.py
@@ -88,8 +88,10 @@
 from rlax._src.policy_gradients import qpg_loss
 from rlax._src.policy_gradients import rm_loss
 from rlax._src.policy_gradients import rpg_loss
+from rlax._src.policy_targets import cmpo_policy_targets
 from rlax._src.policy_targets import constant_policy_targets
 from rlax._src.policy_targets import PolicyTarget
+from rlax._src.policy_targets import sampled_cmpo_policy_targets
 from rlax._src.policy_targets import sampled_policy_distillation_loss
 from rlax._src.policy_targets import zero_policy_targets
 from rlax._src.pop_art import art
@@ -159,6 +161,7 @@
     "categorical_td_learning",
     "clip_gradient",
     "clipped_surrogate_pg_loss",
+    "cmpo_policy_targets",
     "compose_tx",
     "conditional_update",
     "constant_policy_targets",
@@ -230,6 +233,7 @@
     "rpg_loss",
     "sample_start_indices",
+    "sampled_cmpo_policy_targets",
     "sampled_policy_distillation_loss",
     "sarsa",
     "sarsa_lambda",
     "sigmoid",
diff --git a/rlax/_src/policy_targets.py b/rlax/_src/policy_targets.py
index 13736d1..05283df 100644
--- a/rlax/_src/policy_targets.py
+++ b/rlax/_src/policy_targets.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Utilities to construct and learn from policy targets."""
+"""Utilities to construct and learn from policy targets (e.g. in Muesli)."""
 
 import functools
 
@@ -20,6 +20,7 @@
 import distrax
 import jax
 import jax.numpy as jnp
+from rlax._src import base
 
 
 @chex.dataclass(frozen=True)
@@ -106,3 +107,190 @@ def sampled_policy_distillation_loss(
   # We average over the samples, over time and batch, and if the actions are
   # a continuous vector also over the actions.
   return -jnp.mean(weights * jnp.maximum(log_probs, min_logp))
+
+
+def cmpo_policy_targets(
+    prior_distribution,
+    embeddings,
+    rng_key,
+    baseline_value,
+    q_provider,
+    advantage_normalizer,
+    *,
+    num_actions,
+    min_target_advantage=-jnp.inf,
+    max_target_advantage=1.0,
+    kl_weight=1.0,
+) -> PolicyTarget:
+  """Policy targets for Clipped MPO (CMPO).
+
+  The policy targets are, in expectation, proportional to:
+  `prior(a|s) * exp(clip(norm(Q(s, a))))`
+
+  See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al.
+  (https://arxiv.org/pdf/2104.06159.pdf).
+
+  Args:
+    prior_distribution: the prior policy distribution.
+    embeddings: embeddings for the `q_provider`.
+    rng_key: a JAX pseudo random number generator key.
+    baseline_value: the baseline for the `advantage_normalizer`.
+    q_provider: a fn to compute Q-values.
+    advantage_normalizer: a fn to normalize advantages.
+    num_actions: the total number of discrete actions.
+    min_target_advantage: the minimum advantage of a policy target.
+    max_target_advantage: the maximum advantage of a policy target.
+    kl_weight: the coefficient for the KL regularizer.
+
+  Returns:
+    the clipped MPO policy targets.
+  """
+  # Expecting shape [B].
+  chex.assert_rank(baseline_value, 1)
+  rng_key, query_rng_key = jax.random.split(rng_key)
+  del rng_key
+
+  # Producing all actions, with shape [num_actions, B].
+  batch_size, = baseline_value.shape
+  actions = jnp.broadcast_to(
+      jnp.expand_dims(jnp.arange(num_actions, dtype=jnp.int32), axis=-1),
+      [num_actions, batch_size])
+
+  # Using vmap over the num_actions axis (axis=0).
+  def _query_q(actions):
+    return q_provider(
+        # Using the same rng_key for all action samples.
+        rng_key=query_rng_key,
+        action=actions,
+        embeddings=embeddings)
+  qvalues = jax.vmap(_query_q)(actions)
+
+  # Using the same advantage normalization as for policy gradients.
+  raw_advantage = advantage_normalizer(
+      returns=qvalues, baseline_value=baseline_value)
+  clipped_advantage = jnp.clip(
+      raw_advantage, min_target_advantage, max_target_advantage)
+
+  # Constructing and normalizing the weights.
+  log_prior = prior_distribution.log_prob(actions)
+  weights = softmax_policy_target_normalizer(
+      log_prior + clipped_advantage / kl_weight)
+  policy_targets = PolicyTarget(actions=actions, weights=weights)
+  return policy_targets
+
+
+def sampled_cmpo_policy_targets(
+    prior_distribution,
+    embeddings,
+    rng_key,
+    baseline_value,
+    q_provider,
+    advantage_normalizer,
+    *,
+    num_actions=2,
+    min_target_advantage=-jnp.inf,
+    max_target_advantage=1.0,
+    kl_weight=1.0,
+) -> PolicyTarget:
+  """Policy targets for sampled CMPO.
+
+  As in CMPO, the policy targets are, in expectation, proportional to:
+  `prior(a|s) * exp(clip(norm(Q(s, a))))`.
+  However, only a subset of the actions is sampled; this allows scaling to
+  large discrete action spaces and to continuous actions.
+
+  See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al.
+  (https://arxiv.org/pdf/2104.06159.pdf).
+
+  Args:
+    prior_distribution: the prior policy distribution.
+    embeddings: embeddings for the `q_provider`.
+    rng_key: a JAX pseudo random number generator key.
+    baseline_value: the baseline for the `advantage_normalizer`.
+    q_provider: a fn to compute Q-values.
+    advantage_normalizer: a fn to normalize advantages.
+    num_actions: the number of actions sampled from the prior per state.
+    min_target_advantage: the minimum advantage of a policy target.
+    max_target_advantage: the maximum advantage of a policy target.
+    kl_weight: the coefficient for the KL regularizer.
+
+  Returns:
+    the sampled clipped MPO policy targets.
+  """
+  # Expecting shape [B].
+  chex.assert_rank(baseline_value, 1)
+  query_rng_key, action_key = jax.random.split(rng_key)
+  del rng_key
+
+  # Sampling the actions from the prior.
+  actions = prior_distribution.sample(
+      seed=action_key, sample_shape=[num_actions])
+
+  # Using vmap over the num_actions samples (axis=0).
+  def _query_q(actions):
+    return q_provider(
+        # Using the same rng_key for all action samples.
+        rng_key=query_rng_key,
+        action=actions,
+        embeddings=embeddings)
+  qvalues = jax.vmap(_query_q)(actions)
+
+  # Using the same advantage normalization as for policy gradients.
+  raw_advantage = advantage_normalizer(
+      returns=qvalues, baseline_value=baseline_value)
+  clipped_advantage = jnp.clip(
+      raw_advantage, min_target_advantage, max_target_advantage)
+
+  # The weights would be normalized (with an expected value of 1.0) if the
+  # baseline_value were the log of the expected unnormalized weight, i.e.
+  # log(sum_a(prior(a|s) * exp(Q(s, a)/c))).
+  weights = jnp.exp(clipped_advantage / kl_weight)
+
+  # The weights are tiled if the actions have multiple continuous dimensions.
+  # This is OK, because the action is sampled from the joint distribution and
+  # the weight is not based on the per-dimension (non-joint) probabilities.
+  log_prob = prior_distribution.log_prob(actions)
+  weights = jnp.broadcast_to(
+      base.lhs_broadcast(weights, log_prob), log_prob.shape)
+  return PolicyTarget(actions=actions, weights=weights)
+
+
+def softmax_policy_target_normalizer(log_weights):
+  """Returns self-normalized weights.
+
+  Self-normalization introduces a significant bias when the average weight
+  is estimated from a small number of samples.
+
+  Args:
+    log_weights: log unnormalized weights, shape `[num_targets, ...]`.
+
+  Returns:
+    Weights divided by the sample average weight. The weights sum to
+    `num_targets` along axis 0.
+  """
+  num_targets = log_weights.shape[0]
+  return num_targets * jax.nn.softmax(log_weights, axis=0)
+
+
+def loo_policy_target_normalizer(log_weights):
+  """A leave-one-out normalizer for policy target weights.
+
+  Args:
+    log_weights: log unnormalized weights, shape `[num_targets, ...]`.
+
+  Returns:
+    Weights divided by a consistent estimate of the average weight. The
+    weights are not guaranteed to sum to `num_targets`.
+  """
+  num_targets = log_weights.shape[0]
+  weights = jnp.exp(log_weights)
+  # Using a safe, consistent estimator of the average weight that is
+  # independent of the numerator (hence leave-one-out).
+  # The unnormalized weights are already approximately normalized by the
+  # baseline_value, so `1` is used as the initial estimate of the average
+  # weight.
+  avg_weight = (
+      1 + jnp.sum(weights, axis=0, keepdims=True) - weights) / num_targets
+  return weights / avg_weight
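Usage sketch (illustrative, not part of the patch): the snippet below shows how the new `sampled_cmpo_policy_targets` entry point might be called once this change lands and the symbol is exported from `rlax` as above. The toy `q_provider`, `advantage_normalizer`, `readout` matrix, shapes and hyperparameters are assumptions made only for this example; in a real agent the Q-values would come from a learned model and the resulting `PolicyTarget` would typically be consumed by a policy-target loss such as `sampled_policy_distillation_loss` from the same module.

import distrax
import jax
import jax.numpy as jnp
import rlax

batch_size, num_discrete_actions, embed_dim = 3, 8, 16
prior_key, embed_key, target_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Toy prior policy, state embeddings and per-state baseline values.
prior = distrax.Categorical(
    logits=jax.random.normal(prior_key, [batch_size, num_discrete_actions]))
embeddings = jax.random.normal(embed_key, [batch_size, embed_dim])
baseline_value = jnp.zeros([batch_size])

# Hypothetical Q-function: a fixed linear readout of the embedding, indexed by
# the action. It matches the interface used inside the vmapped `_query_q`:
# q_provider(rng_key, action [B], embeddings) -> q-values [B].
readout = 0.01 * jnp.ones([embed_dim, num_discrete_actions])

def q_provider(rng_key, action, embeddings):
  del rng_key  # This toy Q-function is deterministic.
  per_action_q = embeddings @ readout  # [B, num_discrete_actions]
  return jnp.take_along_axis(per_action_q, action[:, None], axis=-1)[:, 0]

def advantage_normalizer(returns, baseline_value):
  return returns - baseline_value  # Plain, unscaled advantages.

targets = rlax.sampled_cmpo_policy_targets(
    prior_distribution=prior,
    embeddings=embeddings,
    rng_key=target_key,
    baseline_value=baseline_value,
    q_provider=q_provider,
    advantage_normalizer=advantage_normalizer,
    num_actions=4,  # actions sampled from the prior per state
    kl_weight=1.0)

# Actions and non-negative weights, both of shape [num_actions, batch_size].
print(targets.actions.shape, targets.weights.shape)  # (4, 3) (4, 3)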