Skip to content

StyleGAN2 model broken in JAX v0.4.36 and newer #41

@Skylion007

Description

@Skylion007
Downloading: "https://www.dropbox.com/s/e8de1peq7p8gq9d/stylegan2_generator_ffhq.h5" to /tmp/flaxmodels/stylegan2_generator_ffhq.h5

100%|██████████| 133M/133M [00:02<00:00, 59.2MiB/s]

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

[/tmp/ipython-input-4126621917.py](https://localhost:8080/#) in <cell line: 0>()
     13 # ['afhqcat', 'afhqdog', 'afhqwild', 'brecahad', 'car', 'cat', 'church', 'cifar10', 'ffhq', 'horse', 'metfaces']
     14 generator = fm.stylegan2.Generator(pretrained='ffhq')
---> 15 params = generator.init(key, z)
     16 images = generator.apply(params, z, train=False)
     17 

    [... skipping hidden 9 frame]

8 frames

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train, noise_mode, rng)
    699                                      name='mapping_network')(z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train)
    700 
--> 701         x = SynthesisNetwork(resolution=self.resolution_,
    702                              num_channels=self.num_channels,
    703                              w_dim=self.w_dim,

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, dlatents_in, noise_mode, rng)
    562         for res in range(2, resolution_log2 + 1):
    563             init_rng, init_key = random.split(init_rng)
--> 564             x, y = SynthesisBlock(fmaps=nf(res - 1),
    565                                   res=res,
    566                                   num_layers=1 if res == 2 else 2,

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, x, y, dlatents_in, noise_mode, rng)
    433         for i in range(self.num_layers):
    434             init_rng, init_key = random.split(init_rng)
--> 435             x = SynthesisLayer(fmaps=self.fmaps, 
    436                                kernel=3,
    437                                layer_idx=self.res * 2 - (5 - i) if self.res > 2 else 0,

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, x, dlatents, noise_mode, rng)
    287         b = ops.equalize_lr_bias(b, self.lr_multiplier)
    288 
--> 289         x = ops.modulated_conv2d_layer(x=x, 
    290                                        w=w,
    291                                        s=s,

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in modulated_conv2d_layer(x, w, s, fmaps, kernel, up, down, demodulate, resample_kernel, fused_modconv)
    459 
    460     # 2D convolution.
--> 461     x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)
    462 
    463     # Reshape/scale output.

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in conv2d(x, w, up, down, resample_kernel, padding)
    415     w = w.astype(x.dtype)
    416     if up:
--> 417         x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
    418     elif down:
    419         x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in upsample_conv_2d(x, w, k, factor, gain, padding)
    401     pad0 = (k.shape[0] + factor - cw) // 2 + padding
    402     pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
--> 403     x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
    404     return x
    405 

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in upfirdn2d(x, f, padding, up, down, strides, flip_filter, gain)
    168 
    169     # upsample by inserting zeros
--> 170     x = jnp.reshape(x, newshape=(B, H, 1, W, 1, C))
    171     x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
    172     x = jnp.reshape(x, newshape=(B, H * up, W * up, C))

[/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in reshape(***failed resolving arguments***)
   2024   # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
   2025   if not isinstance(newshape, DeprecatedArg):
-> 2026     raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
   2027                     " Use shape instead.")
   2028   if shape is None:

TypeError: The newshape argument to jnp.reshape was removed in JAX v0.4.36. Use shape instead.

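Running flaxmodels with the latest JAX version requires a code modification here: the `jnp.reshape(..., newshape=...)` calls in `flaxmodels/stylegan2/ops.py` (`upfirdn2d`, lines 170 and 172 in the traceback above) need to pass `shape=` instead of `newshape=`. I hit this while trying the demo StyleGAN2 notebook.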

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions