
Post-training error for SiT #3

@GpenAI28


Dear Author,
I noticed a discrepancy between the released code and the description in your paper.

The paper states that continued post-training is performed on SiT, REPA, and RAE. However, when I tested SiT post-training using the official pre-trained weights (SiT-XL-2-256.pt) downloaded from the SiT repository (https://github.com/willisma/SiT) as the starting point, I encountered an error. On inspection, the official checkpoint does not contain the keys that Mixflow expects.

In addition, there is a clear ordering bug in the Mixflow-SiT training code: opt (line 163 in train.py) is referenced before it is defined.

Could you please clarify these issues?

Command:

torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py \
  --data-path "dataset/imagenet/train" \
  --results-dir results/mixflow-SiT \
  --model SiT-XL/2 \
  --image-size 256 \
  --global-batch-size 32 \
  --ckpt models/sit/SiT-XL-2-256.pt \
  --wandb

The checkpoint-loading code in Mixflow's train.py, which expects these keys:

if args.ckpt is not None:
    ckpt_path = args.ckpt
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict["model"])
    ema.load_state_dict(state_dict["ema"])
    opt.load_state_dict(state_dict["opt"])
    args = state_dict["args"]
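A minimal sketch of a more tolerant loading step, assuming the released SiT-XL-2-256.pt is a bare state dict of weights rather than a training checkpoint with "model"/"ema"/"opt"/"args" sections (the KeyError only confirms that "model" is missing). The helper name `load_posttrain_checkpoint` is hypothetical and not part of Mixflow or SiT. It only calls `load_state_dict`, so it should work with a torch module and optimizer, but `opt` must already be constructed when it is called, which is exactly what the line-163 ordering problem prevents.

```python
from typing import Any, Mapping, Optional


def load_posttrain_checkpoint(state_dict: Mapping[str, Any], model: Any, ema: Any,
                              opt: Optional[Any] = None) -> Optional[Any]:
    """Hypothetical helper: load either a full training checkpoint or bare weights.

    `model`, `ema`, and `opt` only need a `load_state_dict` method
    (e.g. torch.nn.Module and torch.optim.Optimizer). `opt` must already exist.
    Returns the saved `args` if the checkpoint contains them, otherwise None.
    """
    if "model" in state_dict:
        # Full training checkpoint: restore each section that is present.
        model.load_state_dict(state_dict["model"])
        ema.load_state_dict(state_dict.get("ema", state_dict["model"]))
        if opt is not None and "opt" in state_dict:
            opt.load_state_dict(state_dict["opt"])
        return state_dict.get("args")
    # Assumed format of the released SiT checkpoint: bare weights only.
    # Initialise both the online model and the EMA copy from the same tensors
    # and leave the optimizer in its freshly constructed state.
    model.load_state_dict(state_dict)
    ema.load_state_dict(state_dict)
    return None
```

In train.py this would be called only after `model`, `ema`, and `opt` have all been created, passing the same `state_dict` the snippet above obtains from `find_model(ckpt_path)`, and `args` would be replaced only when the return value is not None. Starting the EMA copy from the same weights is a choice for a fresh post-training run and may not be what the authors intended.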

Error:

[rank0]:     model.load_state_dict(state_dict["model"])
[rank0]: KeyError: 'model'
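One way to confirm which format a checkpoint file uses is to look at its top-level keys. A minimal sketch, assuming the file has already been loaded into a dict; the helper name `checkpoint_kind` is hypothetical. Parameter names in a bare state dict are dotted module paths, whereas a training checkpoint has a few section names such as "model" and "ema".

```python
from typing import Any, Mapping

# Top-level sections read by the Mixflow loading code quoted above.
EXPECTED_SECTIONS = ("model", "ema", "opt", "args")


def checkpoint_kind(ckpt: Mapping[str, Any]) -> str:
    """Heuristically classify a loaded checkpoint by its top-level keys."""
    keys = list(ckpt.keys())
    missing = [k for k in EXPECTED_SECTIONS if k not in ckpt]
    if not missing:
        return "training checkpoint"
    if keys and all(isinstance(k, str) and "." in k for k in keys):
        # Dotted names such as "x_embedder.proj.weight" look like raw module parameters.
        return "bare state dict (missing: " + ", ".join(missing) + ")"
    return "unknown layout (missing: " + ", ".join(missing) + ")"
```

A result of "bare state dict" would be consistent with the KeyError: 'model' shown above.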
