Hi, nice repo! Is this line a bug? I think `batch['images']` has shape `N x B x H x W x C`, so the indices should be shifted up by 1: https://github.com/matthias-wright/flaxmodels/blob/edc6a8571a6d7202bd9f3bc9241221405c083fd4/training/stylegan2/training.py#L239
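Here is a minimal NumPy sketch of the off-by-one I mean. It assumes the batch is sharded with a leading device axis `N`, as described above. The shapes and variable names are made up for illustration, and I haven't checked them against the actual training loop.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
num_devices, per_device_batch, height, width, channels = 2, 4, 32, 32, 3

# Unsharded batch, B x H x W x C: height is axis 1, width is axis 2.
flat = np.zeros((num_devices * per_device_batch, height, width, channels))
assert flat.shape[1] == height and flat.shape[2] == width

# Sharded across devices (e.g. for jax.pmap), N x B x H x W x C:
# the leading device axis shifts every other axis up by one,
# so height is now axis 2 and width is axis 3.
sharded = flat.reshape(num_devices, per_device_batch, height, width, channels)
assert sharded.shape[2] == height and sharded.shape[3] == width
```

If that line indexes the spatial axes assuming the unsharded layout, it would be reading the batch and height dimensions instead of height and width.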