Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/scan/test_scan_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,38 @@ def test_scan_layers_cache_non_pure(self):
# Check that the cache is not populated.
self.assertEqual(len(scan_layers_module._ONE_LAYER_CACHE), 0)

@parameterized.parameters(False, True)
def test_no_weights_layers(self, is_layer_pure: bool):
  """Run `scan_layers` over layers that own no parameters or buffers.

  A stack of identical stateless modules is applied twice to the same
  input: once through `scan_layers` and once through an ordinary Python
  loop. The two outputs are then compared within a small tolerance.

  Args:
    is_layer_pure: Forwarded unchanged to `scan_layers`; the test is
      parameterized over both `False` and `True`.
  """

  class PureComputeModule(torch.nn.Module):
    # Registers no parameters or buffers; forward is plain arithmetic.

    def forward(self, x):
      return x * 2 + 1

  num_layers = 10
  modules = [PureComputeModule().to(self.device) for _ in range(num_layers)]
  sample = torch.randn(64).to(self.device)
  torch_xla.sync(wait=True)

  # Independent copies so the scan path and the loop path share no modules.
  scan_modules = deepcopy(modules)
  loop_modules = deepcopy(modules)
  torch_xla.sync()

  scanned = scan_layers(
      scan_modules, sample.clone(), is_layer_pure=is_layer_pure)
  # Presumably verifies that the scan was lowered to an HLO `While` op.
  self.assert_while_found_in_hlo(scanned)
  torch_xla.sync()

  # Reference result: apply the layers one after another.
  expected = sample.clone()
  for module in loop_modules:
    expected = module(expected)
  torch_xla.sync()

  super().compareResults(expected, scanned, abs_err=0.0001, rel_err=0.001)
  self.assert_different_tensor(expected, scanned)


if __name__ == '__main__':
test = unittest.main()
Expand Down
46 changes: 35 additions & 11 deletions torch_xla/experimental/scan_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

def _create_or_get_cached_one_layer_fn(first_layer: nn.Module,
partition_fn,
is_layer_pure: bool = False):
is_layer_pure: bool = False,
has_params_or_buffers: bool = True):
cache_key = (id(partition_fn), id(first_layer))
if is_layer_pure and cache_key in _ONE_LAYER_CACHE:
return _ONE_LAYER_CACHE[cache_key]
Expand All @@ -25,13 +26,23 @@ def _create_or_get_cached_one_layer_fn(first_layer: nn.Module,
from copy import deepcopy
example_layer = deepcopy(first_layer)

# Define the function to apply at each step
def one_layer_fn(carry, params_buffers):
# Apply the current layer's weights and biases to the example layer,
# then run the resulting layer.
output = torch.func.functional_call( # type: ignore
example_layer, params_buffers, carry, strict=True)
return output, None
if has_params_or_buffers:

# Define the function to apply at each step
def one_layer_fn(carry, params_buffers):
# Apply the current layer's weights and biases to the example layer,
# then run the resulting layer.
output = torch.func.functional_call( # type: ignore
example_layer, params_buffers, carry, strict=True)
return output, None
else:

# When the layer has no parameters or buffers, we don't need
# functional_call. Just run the layer directly, ignoring the dummy
# tensor passed as xs.
def one_layer_fn(carry, _dummy):
output = example_layer(carry)
return output, None

if is_layer_pure:
# Cache the function for pure layers to avoid recomputing it.
Expand Down Expand Up @@ -114,10 +125,23 @@ def scan_layers(layers: Iterable[torch.nn.Module],
stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
*buffers_list)

one_layer = _create_or_get_cached_one_layer_fn(first_layer, partition_fn,
is_layer_pure)
num_layers = len(params_and_buffers)
has_params_or_buffers = any(len(d) > 0 for d in (*params_list, *buffers_list))

one_layer = _create_or_get_cached_one_layer_fn(
first_layer,
partition_fn,
is_layer_pure,
has_params_or_buffers=has_params_or_buffers)

if has_params_or_buffers:
stacked_params_buffers = (stacked_params, stacked_buffers)
else:
# When layers have no parameters or buffers, `scan` still needs a tensor
# with a leading dimension to determine the number of iterations.
# Provide a dummy tensor of shape (num_layers,) for this purpose.
stacked_params_buffers = torch.zeros(num_layers)

stacked_params_buffers = (stacked_params, stacked_buffers)
final_carry, _ = scan(
one_layer,
input_data,
Expand Down
Loading