Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 85 additions & 56 deletions hessian/model_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


def qvalue(array):
return jnp.linalg.norm(array.reshape(-1))**2 / array.size
return jnp.linalg.norm(array.reshape(-1)) ** 2 / array.size


def cvalue(activations):
Expand All @@ -55,10 +55,9 @@ def tag_qcvalue(module, activations, name):
module.sow('qcvalues', name, qc_value)


def tag_residual_activations(module,
identity_path,
other_path,
name='residual'):
def tag_residual_activations(
module, identity_path, other_path, name='residual'
):
"""Used in inspecting the forward pass and residual networks.

Residual connections involve adding x + F(x) for some function F. This
Expand All @@ -79,15 +78,17 @@ def tag_residual_activations(module,
'qvalues',
name + 'q', # avoid scope collision with the cvalue
jnp.array((res_values_q, add_values_q)),
reduce_fn=lambda x, y: y)
reduce_fn=lambda x, y: y,
)

res_values_c = cvalue(identity_path)
add_values_c = cvalue(other_path)
module.sow(
'cvalues',
name + 'c',
jnp.array((res_values_c, add_values_c)),
reduce_fn=lambda x, y: y)
reduce_fn=lambda x, y: y,
)


def pmap_then_unreplicate(leaf_fn, axis_name='batch'):
Expand All @@ -106,10 +107,12 @@ def pmap_then_unreplicate(leaf_fn, axis_name='batch'):
the values on device 0.
"""
p_leaf_fn = jax.pmap(leaf_fn, axis_name=axis_name)

def tree_fn(*args):
vals = p_leaf_fn(*args)
vals_on_host = jax.tree.map(lambda x: x[0], vals)
return vals_on_host

return tree_fn


Expand Down Expand Up @@ -138,9 +141,9 @@ def append_pytree_leaves(full_pytree, to_append):
return jax.tree.map(lambda x, y: array_append(y, x), to_append, full_pytree)


def create_forward_pass_stats_fn(apply_fn,
capture_activation_norms=False,
sown_collection_names=None):
def create_forward_pass_stats_fn(
apply_fn, capture_activation_norms=False, sown_collection_names=None
):
"""Creates a function which grabs intermediate values from forward pass.

If capture_activation_norms=True then we run the forward pass with
Expand Down Expand Up @@ -168,6 +171,7 @@ def create_forward_pass_stats_fn(apply_fn,
mutables = ['intermediates', 'batch_stats']
if sown_collection_names is not None:
mutables.extend(sown_collection_names)

def get_forward_pass_statistics(params, batch, rng):
_, forward_pass_statistics = apply_fn(
params=params,
Expand All @@ -176,7 +180,8 @@ def get_forward_pass_statistics(params, batch, rng):
rngs={'dropout': rng},
capture_intermediates=capture_activation_norms,
mutable=mutables,
train=True)
train=True,
)
forward_pass_statistics = flax.core.unfreeze(forward_pass_statistics)

# We run in train mode but throw away the updated batch stats because we
Expand All @@ -186,14 +191,17 @@ def get_forward_pass_statistics(params, batch, rng):
if 'intermediates' in forward_pass_statistics:
# This calculation corresponds to the average q-value across the batch.
forward_pass_statistics['intermediate_qvalue'] = jax.tree.map(
qvalue, forward_pass_statistics['intermediates'])
qvalue, forward_pass_statistics['intermediates']
)
forward_pass_statistics['intermediate_cvalue'] = jax.tree.map(
cvalue, forward_pass_statistics['intermediates'])
cvalue, forward_pass_statistics['intermediates']
)

# Don't want to store the full activations.
forward_pass_statistics.pop('intermediates')

return forward_pass_statistics

return get_forward_pass_statistics


Expand All @@ -215,7 +223,9 @@ def remove_leaf_tuples(tree):
def unflatten(d):
return flax.core.freeze(
flax.traverse_util.unflatten_dict(
{tuple(k.split('/')): v for k, v in d.items()}))
{tuple(k.split('/')): v for k, v in d.items()}
)
)


# Not a JAX Array Type - just an empty holder of static information:
Expand Down Expand Up @@ -282,15 +292,17 @@ def build_skip_flags(paths_to_skip):
class ModelDebugger:
"""Debugging tool for internal layers of a model."""

def __init__(self,
forward_pass=None,
grad_fn=None,
use_pmap=True,
save_every=1,
metrics_logger=None,
pytree_metrics_logger=None,
skip_flags=None,
skip_groups=None):
def __init__(
self,
forward_pass=None,
grad_fn=None,
use_pmap=True,
save_every=1,
metrics_logger=None,
pytree_metrics_logger=None,
skip_flags=None,
skip_groups=None,
):
"""Used to inspect a models forward and backward pass.

The following keys are required in config -
Expand All @@ -315,8 +327,10 @@ def __init__(self,
attention layer jacobians contribute to the vanishing gradient problem"?
"""
if metrics_logger and (metrics_logger._json_path is None):
raise ValueError('To use the ModelDebugger with a metrics_logger, a json'
' path must be specified when building metrics_logger')
raise ValueError(
'To use the ModelDebugger with a metrics_logger, a json'
' path must be specified when building metrics_logger'
)
self._save_every = save_every
self._metrics_logger = metrics_logger
self._pytree_metrics_logger = pytree_metrics_logger
Expand Down Expand Up @@ -349,31 +363,36 @@ def __init__(self,
):
self._stored_metrics = pytree_metrics_logger.load_latest_pytree()

def _grab_statistics(self,
step,
batch=None,
rng=None,
params=None,
grad=None,
update=None,
grad_norms_sql2=None,
update_norms_sql2=None,
param_norms_sql2=None):
def _grab_statistics(
self,
step,
batch=None,
rng=None,
params=None,
grad=None,
update=None,
grad_norms_sql2=None,
update_norms_sql2=None,
param_norms_sql2=None,
):
"""Computes layerwise gradient and parameter norm statistics."""
metrics_dict = {'step': step}
if grad is None and grad_norms_sql2 is None and self.grad_fn is not None:
grad = self.grad_fn(params, batch, rng)

for tup in zip([params, grad, update],
[param_norms_sql2, grad_norms_sql2, update_norms_sql2],
['param', 'grad', 'update']):
for tup in zip(
[params, grad, update],
[param_norms_sql2, grad_norms_sql2, update_norms_sql2],
['param', 'grad', 'update'],
):
variable_tree, norm_tree_sql2, key = tup
if variable_tree and not norm_tree_sql2:
norm_tree_sql2 = self._tree_norm_fn_sql2(variable_tree)
if norm_tree_sql2:
metrics_dict['{}_norms_sql2'.format(key)] = norm_tree_sql2
metrics_dict['global_{}_norm_sql2'.format(key)] = sum(
jax.tree.leaves(norm_tree_sql2))
jax.tree.leaves(norm_tree_sql2)
)

return metrics_dict

Expand Down Expand Up @@ -451,18 +470,20 @@ def run_skip_analysis(self, params, batch, rng):
grad_dict['unmodified_gradient'] = self._tree_norm_fn_sql2(new_g)
return {'skip_analysis': grad_dict}

def full_eval(self,
step,
params=None,
grad=None,
update=None,
fwd_pass_summaries=None,
param_norms_sql2=None,
grad_norms_sql2=None,
update_norms_sql2=None,
extra_scalar_metrics=None,
batch=None,
rng=None):
def full_eval(
self,
step,
params=None,
grad=None,
update=None,
fwd_pass_summaries=None,
param_norms_sql2=None,
grad_norms_sql2=None,
update_norms_sql2=None,
extra_scalar_metrics=None,
batch=None,
rng=None,
):
"""Computes statistics of the forward and backward pass and save to disk.

Currently what is written to disk is a pytree, with a dict at the top level
Expand Down Expand Up @@ -511,8 +532,13 @@ def full_eval(self,
"""
all_metrics = {'step': step}
if any([
params, grad, update, param_norms_sql2, grad_norms_sql2,
update_norms_sql2, self.grad_fn
params,
grad,
update,
param_norms_sql2,
grad_norms_sql2,
update_norms_sql2,
self.grad_fn,
]):
metrics_dict = self._grab_statistics(
step=step,
Expand All @@ -523,7 +549,8 @@ def full_eval(self,
update=update,
grad_norms_sql2=grad_norms_sql2,
update_norms_sql2=update_norms_sql2,
param_norms_sql2=param_norms_sql2)
param_norms_sql2=param_norms_sql2,
)
all_metrics.update(metrics_dict)

if extra_scalar_metrics:
Expand All @@ -535,7 +562,8 @@ def full_eval(self,
if self.forward_pass and not fwd_pass_summaries:
if batch is None:
raise ValueError(
'Must supply a batch when computing forward pass stats.')
'Must supply a batch when computing forward pass stats.'
)

fwd_pass_summaries = self.forward_pass(params, batch, rng)

Expand All @@ -554,7 +582,8 @@ def full_eval(self,

self._stored_metrics = append_pytree_leaves(
self._stored_metrics,
remove_leaf_tuples(flax.core.unfreeze(all_metrics)))
remove_leaf_tuples(flax.core.unfreeze(all_metrics)),
)

if self._metrics_logger and jax.process_index() == 0:
self._maybe_save_metrics(step)
Expand Down
31 changes: 15 additions & 16 deletions hessian/model_debugger_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,9 @@
}


def get_grad(params,
batch,
rng,
batch_stats=None,
module_flags=None,
training_cost=None):
def get_grad(
params, batch, rng, batch_stats=None, module_flags=None, training_cost=None
):
"""Single step of the training loop.

Args:
Expand All @@ -57,13 +54,11 @@ def get_grad(params,
kwargs = {'module_flags': module_flags}
else:
kwargs = {}

def opt_cost(params):
return training_cost(
params,
batch,
batch_stats=batch_stats,
dropout_rng=rng,
**kwargs)
params, batch, batch_stats=batch_stats, dropout_rng=rng, **kwargs
)

grad_fn = jax.value_and_grad(opt_cost, has_aux=True)
_, grad = grad_fn(params)
Expand Down Expand Up @@ -104,7 +99,8 @@ def __init__(
get_act_stats_fn = model_debugger.create_forward_pass_stats_fn(
model.apply_on_batch,
capture_activation_norms=True,
sown_collection_names=callback_config.get('sown_collection_names'))
sown_collection_names=callback_config.get('sown_collection_names'),
)
batch_stats = jax.tree.map(lambda x: x[:][0], batch_stats)
grad_fn = functools.partial(
get_grad,
Expand All @@ -117,7 +113,8 @@ def __init__(
metrics_logger=logger,
grad_fn=grad_fn,
skip_flags=callback_config.get('skip_flags'),
skip_groups=callback_config.get('skip_groups'))
skip_groups=callback_config.get('skip_groups'),
)
# pmap functions for the training loop
# in_axes = (params = 0, batch_stats = 0, batch = 0, step = None,
# lr = None, rng = None, local_device_index = 0, training_metrics_grabber=0,
Expand Down Expand Up @@ -150,14 +147,16 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step):
"""
del optimizer_state
del batch_stats
p_norms = jax.tree.map(lambda x: jnp.linalg.norm(x[0].reshape(-1))**2,
params)
p_norms = jax.tree.map(
lambda x: jnp.linalg.norm(x[0].reshape(-1)) ** 2, params
)

self.debugger.full_eval(
step=global_step,
params=params,
param_norms_sql2=p_norms,
batch=self.batch,
rng=self.batch_rng)
rng=self.batch_rng,
)

return {}
Loading