Code vs paper #10

@EelcoHoogendoorn

Hi all,

I'm working on a JAX implementation of a HiPPO gated RNN. I wasn't quite sure how to interpret the diagram in the paper linked below, and I cannot quite reconcile it with the torch implementation, also linked below. The code seems more sensible to me than the paper: it makes sense to me to have the SSM see the raw, unfiltered data, whereas placing the nonlinear gated unit in front of it seems like it might jeopardize the unobstructed flow of gradients along the sequence.

The version from the diagram in the paper works quite well in my use case, though the torch version seems to converge more quickly. Just curious whether I'm misreading something here, or what your latest thinking on these matters is.

EDIT: I had very good experience with the paper version in terms of avoiding exploding gradients. The code version seems to converge faster and more smoothly initially, but I do observe that its gated unit can explode on longer trajectories. Lots of things to explore here, I suppose; deep/stacked SSMs with pointwise nonlinearities have not worked for me so far.

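import jax.numpy as jnp
from flax import linen as nn

# Two separate classes, one per reading; two __call__ methods on a single
# class would not work, since the second definition replaces the first.
# MGUCell is a gated recurrent cell defined elsewhere (not shown here).


class GatedHippoPaper(nn.Module):
    """
    linear state-space-model coupled with a gated-nonlinear module

    How I read the diagram in the paper
    https://arxiv.org/pdf/2008.07669.pdf
    """
    ssm: nn.Module
    rnn: nn.Module = MGUCell()

    @nn.compact
    def __call__(self, carry, inputs):
        # gated unit first: it sees the previous ssm state and the raw input
        rnn_carry, rnn_output = self.rnn(carry['rnn'], jnp.concatenate([carry['ssm'], inputs]))
        # the ssm only sees the output of the gated unit
        ssm_carry, ssm_output = self.ssm(carry['ssm'], rnn_output)
        carry = {'rnn': rnn_carry, 'ssm': ssm_carry}
        return carry, rnn_output  # should we use ssm output here, lest it goes unused?


class GatedHippoCode(nn.Module):
    """
    linear state-space-model coupled with a gated-nonlinear module

    How I read the torch code
    https://github.com/HazyResearch/hippo-code/blob/201148256bd2b71cb07668dc00075420cfd4c567/model/model.py#L79
    """
    ssm: nn.Module
    rnn: nn.Module = MGUCell()

    @nn.compact
    def __call__(self, carry, inputs):
        # ssm first: it sees the raw input
        ssm_carry, ssm_output = self.ssm(carry['ssm'], inputs)
        # the gated unit sees the fresh ssm output and the raw input
        rnn_carry, rnn_output = self.rnn(carry['rnn'], jnp.concatenate([ssm_output, inputs]))
        carry = {'rnn': rnn_carry, 'ssm': ssm_carry}
        return carry, rnn_output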
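
For anyone who wants to poke at the two wirings without the full HiPPO machinery, here is a minimal sketch in plain JAX. Everything in it is an assumption made for illustration: `ssm_step` is a generic linear recurrence with a placeholder `A` matrix (not a HiPPO operator), `gate_step` is a bare-bones single-gate unit (not the `MGUCell` above), and the helper names `step_paper`, `step_code`, `init_params`, `run` and `grad_norm` are hypothetical. It only shows the order in which data flows through the two modules, plus a crude way to watch the gradient norm as the trajectory length changes.

```python
import jax
import jax.numpy as jnp


def ssm_step(params, state, x):
    """Linear state update (placeholder for a HiPPO ssm): returns (carry, output)."""
    new_state = params['A'] @ state + params['B'] @ x
    return new_state, new_state


def gate_step(params, h, x):
    """Bare-bones single-gate unit (placeholder for MGUCell): returns (carry, output)."""
    f = jax.nn.sigmoid(params['Wf'] @ jnp.concatenate([h, x]))
    cand = jnp.tanh(params['Wh'] @ jnp.concatenate([f * h, x]))
    new_h = (1.0 - f) * h + f * cand
    return new_h, new_h


def step_paper(params, carry, x):
    # gated unit first (previous ssm state + raw input), then ssm on its output
    rnn_in = jnp.concatenate([carry['ssm'], x])
    rnn_carry, rnn_out = gate_step(params['rnn'], carry['rnn'], rnn_in)
    ssm_carry, _ = ssm_step(params['ssm'], carry['ssm'], rnn_out)
    return {'rnn': rnn_carry, 'ssm': ssm_carry}, rnn_out


def step_code(params, carry, x):
    # ssm first (raw input), then gated unit on fresh ssm output + raw input
    ssm_carry, ssm_out = ssm_step(params['ssm'], carry['ssm'], x)
    rnn_in = jnp.concatenate([ssm_out, x])
    rnn_carry, rnn_out = gate_step(params['rnn'], carry['rnn'], rnn_in)
    return {'rnn': rnn_carry, 'ssm': ssm_carry}, rnn_out


def init_params(key, d_in, d_h, d_s, ssm_in):
    """ssm_in is d_h for the paper wiring and d_in for the code wiring."""
    k_b, k_f, k_h = jax.random.split(key, 3)
    gate_in = d_h + d_s + d_in
    return {
        'ssm': {'A': 0.9 * jnp.eye(d_s),  # arbitrary stable dynamics
                'B': 0.1 * jax.random.normal(k_b, (d_s, ssm_in))},
        'rnn': {'Wf': 0.1 * jax.random.normal(k_f, (d_h, gate_in)),
                'Wh': 0.1 * jax.random.normal(k_h, (d_h, gate_in))},
    }


def run(step, params, xs, d_h, d_s):
    """Scan one of the step functions over a sequence xs of shape (T, d_in)."""
    carry0 = {'rnn': jnp.zeros(d_h), 'ssm': jnp.zeros(d_s)}
    _, ys = jax.lax.scan(lambda c, x: step(params, c, x), carry0, xs)
    return ys


def grad_norm(step, params, xs, d_h, d_s):
    """Global norm of d(loss on last output)/d(params)."""
    loss = lambda p: jnp.sum(run(step, p, xs, d_h, d_s)[-1] ** 2)
    leaves = jax.tree_util.tree_leaves(jax.grad(loss)(params))
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))


d_in, d_h, d_s, T = 3, 4, 5, 50
xs = jax.random.normal(jax.random.PRNGKey(0), (T, d_in))
p_paper = init_params(jax.random.PRNGKey(1), d_in, d_h, d_s, ssm_in=d_h)
p_code = init_params(jax.random.PRNGKey(2), d_in, d_h, d_s, ssm_in=d_in)

ys_paper = run(step_paper, p_paper, xs, d_h, d_s)
ys_code = run(step_code, p_code, xs, d_h, d_s)
print(ys_paper.shape, ys_code.shape)  # both are (T, d_h)
print(grad_norm(step_paper, p_paper, xs, d_h, d_s),
      grad_norm(step_code, p_code, xs, d_h, d_s))
```

Note that the only structural difference is what the ssm gets to see: the gated output in `step_paper`, the raw input in `step_code`. Re-running `grad_norm` for several values of `T` would be one rough way to compare how the two wirings behave on longer trajectories, though with these toy stand-ins the numbers say nothing about the real HiPPO model.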
