5 changes: 4 additions & 1 deletion .github/workflows/claude.yml
@@ -9,6 +9,9 @@ on:
types: [opened, assigned]
pull_request_review:
types: [submitted]
pull_request_target:
types: [opened, synchronize]
Comment on lines 11 to +13

Claude: Warning — Security: pull_request_target runs in the context of the base branch and has access to repository secrets, unlike the safer pull_request event. While the Claude Code action uses restricted allowed_tools, the pull_request_target trigger is a known attack vector for PRs from untrusted forks — the PR author could craft malicious content in the PR title, body, or diff that influences the Claude agent's behavior in a context where secrets are available.

Consider whether pull_request (which does NOT have access to secrets) is sufficient for code review, or add explicit if: github.event.pull_request.head.repo.full_name == github.repository guards to limit this to same-repo PRs only.
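A minimal sketch of that guard, assuming it is placed at the job level (the job name `claude` is taken from this workflow; adjust to match):

```yaml
# Sketch only: skip the job for fork PRs, so pull_request_target never
# exposes secrets to untrusted heads. Same-repo PRs still run.
jobs:
  claude:
    if: >-
      github.event_name != 'pull_request_target' ||
      github.event.pull_request.head.repo.full_name == github.repository
```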

branches: [main]

jobs:
claude:
@@ -38,7 +41,7 @@ jobs:
# Prompt: a workaround for the Claude Code action bug with fork PRs
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}
PR NUMBER: ${{ github.event.pull_request.number || github.event.issue.number}}

Claude: Nit: Missing space before closing }}:

Suggested change
PR NUMBER: ${{ github.event.pull_request.number || github.event.issue.number}}
PR NUMBER: ${{ github.event.pull_request.number || github.event.issue.number }}


Please review this pull request.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -130,3 +130,4 @@ skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
111

Claude: Critical: This line (111) appears to be a test artifact from the first commit. It will break TOML parsing since it's not valid TOML syntax. This must be removed before merging.

Suggested change
111

5 changes: 5 additions & 0 deletions xtuner/v1/loss/base_loss_ctx.py
@@ -12,6 +12,7 @@
from typing_extensions import Self

from xtuner.v1.loss.utils import sp_split
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo

from .chunk_loss import ChunkLoss

@@ -195,6 +196,10 @@ def forward(
else:
loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)

# TODO: yanhuida, should be removed
if not isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info = ModelForwardExtraLogInfo(extra_info)
Comment on lines 198 to +201

Claude: Warning: This TODO-guarded workaround wraps extra_info in ModelForwardExtraLogInfo if it isn't already one. Two concerns:

  1. This introduces a cross-layer dependency: loss/base_loss_ctx.py now imports from model/utils/misc.py. If extra_info can arrive here as a plain dict, the root cause should be fixed at the source (where extra_info is produced), not patched at the consumer.
  2. The # TODO: yanhuida, should be removed comment doesn't explain when or under what condition this should be removed, making it easy to forget. Consider adding a reference to a tracking issue.
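A hedged sketch of fixing this at the producer instead (the names below are stand-ins, not the real xtuner classes):

```python
# Sketch: normalize extra_info where it is produced, so consumers such as
# base_loss_ctx never need the isinstance workaround. ExtraLogInfo is a
# stand-in for ModelForwardExtraLogInfo.

class ExtraLogInfo(dict):
    """Stand-in for ModelForwardExtraLogInfo (a dict-like log container)."""


def as_extra_log_info(obj) -> "ExtraLogInfo":
    # Idempotent: already-wrapped values pass through unchanged.
    if isinstance(obj, ExtraLogInfo):
        return obj
    return ExtraLogInfo(obj)


wrapped = as_extra_log_info({"local_base_loss": 0.5})
rewrapped = as_extra_log_info(wrapped)
print(type(wrapped).__name__, wrapped is rewrapped)  # ExtraLogInfo True
```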


extra_info["local_base_loss"] = loss.detach().clone()

# Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
22 changes: 18 additions & 4 deletions xtuner/v1/model/base.py
@@ -194,11 +194,25 @@ def layers_type(self) -> list[Literal["full_attention", "sliding_attention"]]:
]


class ModelOutputs(TypedDict):
hidden_states: NotRequired[list[torch.Tensor]]
logits: NotRequired[torch.Tensor]
class ModelOutputs(PydanticBaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
hidden_states: list[torch.Tensor] | None = None
logits: torch.Tensor | None = None
loss: torch.Tensor
extra_info: ModelForwardExtraLogInfo
extra_info: ModelForwardExtraLogInfo | None = None

def free(self):
self.hidden_states = None
self.logits = None
self.extra_info = None

# TODO: Kept only for backward compatibility. Should be removed later.
def __getitem__(self, key):
return getattr(self, key)

# TODO: Kept only for backward compatibility. Should be removed later.
def __contains__(self, key):
Comment on lines +213 to +214

Claude: Warning: __contains__ uses model_fields_set which only tracks fields explicitly passed during construction. This works correctly for the current code patterns (where fields are conditionally added to the output dict before construction), but note this is semantically different from checking hasattr or checking if a value is non-None.

For example, MoEModelOutputs(loss=loss, extra_info=info, tokens_per_expert_global=tpe, balancing_loss=None) would report "balancing_loss" in output as True (because it was explicitly passed), even though the value is None. This may be surprising to callers who expect in to indicate a meaningful value is present. Consider documenting this behavior or checking getattr(self, key) is not None instead.

return key in self.model_fields_set
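The distinction can be sketched without Pydantic; `fields_set` below mimics `model_fields_set` (it tracks what was passed, not whether the value is meaningful):

```python
# Sketch (not Pydantic): __contains__ backed by the set of explicitly
# passed fields reports True even when the passed value is None.

class OutputsSketch:
    _defaults = {"loss": None, "balancing_loss": None}

    def __init__(self, **kwargs):
        self.fields_set = set(kwargs)  # mirrors Pydantic's model_fields_set
        for name, default in self._defaults.items():
            setattr(self, name, kwargs.get(name, default))

    def __contains__(self, key):
        return key in self.fields_set


out = OutputsSketch(loss=1.0, balancing_loss=None)
print("balancing_loss" in out)         # True: passed explicitly, though None
print(out.balancing_loss is not None)  # False: no meaningful value
```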


def _is_float8_available():
2 changes: 1 addition & 1 deletion xtuner/v1/model/dense/dense.py
@@ -107,7 +107,7 @@ def forward(
output["loss"] = loss
output["logits"] = logits
output["extra_info"] = extra_info
return ModelOutputs(**output) # type: ignore[typeddict-item]
return ModelOutputs(**output)

def build_embeddings(self, config: TransformerConfig):
return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
16 changes: 10 additions & 6 deletions xtuner/v1/model/moe/moe.py
@@ -79,10 +79,14 @@


class MoEModelOutputs(ModelOutputs):
router_logits: NotRequired[dict[str, torch.Tensor]]
balancing_loss: NotRequired[torch.Tensor]
z_loss: NotRequired[torch.Tensor]
tokens_per_expert_global: NotRequired[torch.Tensor]
router_logits: dict[str, torch.Tensor] | None = None
balancing_loss: torch.Tensor | None = None
z_loss: torch.Tensor | None = None
tokens_per_expert_global: torch.Tensor

Claude: Critical: tokens_per_expert_global is declared as required (no default value), but the original TypedDict had it as NotRequired[torch.Tensor].

In _micro_batch_forward (line ~463), tokens_per_expert_global is only added to the output dict inside the if all_router_logits: block. If that condition is false, constructing MoEModelOutputs(**output) will raise a Pydantic ValidationError because this required field is missing.

This should be optional to match the original behavior:

Suggested change
tokens_per_expert_global: torch.Tensor
tokens_per_expert_global: torch.Tensor | None = None
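The failure mode can be sketched with a dataclass stand-in (same required/optional split as the Pydantic model in the diff; Pydantic would raise ValidationError where the dataclass raises TypeError):

```python
from dataclasses import dataclass
from typing import Optional

# Sketch (dataclass stand-in for the Pydantic model): a field with no
# default is required, so constructing without it fails immediately.
@dataclass
class MoEOutputsSketch:
    loss: float
    tokens_per_expert_global: float          # required as declared in the diff
    balancing_loss: Optional[float] = None   # optional fields need defaults

try:
    MoEOutputsSketch(loss=0.1)  # tokens_per_expert_global missing
except TypeError as exc:
    print("construction failed:", exc)
```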


def free(self):
super().free()
self.router_logits = None
Comment on lines +87 to +89

Claude: Warning: MoEModelOutputs.free() only clears router_logits but not balancing_loss, z_loss, or tokens_per_expert_global. For consistency with ModelOutputs.free() (which clears all its tensor fields), all MoE-specific tensor fields should also be freed:

Suggested change
def free(self):
super().free()
self.router_logits = None
def free(self):
super().free()
self.router_logits = None
self.balancing_loss = None
self.z_loss = None
self.tokens_per_expert_global = None



class BalancingLossConfig(PydanticBaseModel):
@@ -482,7 +486,7 @@ def _micro_batch_forward(

output["router_logits"] = router_logits_dict

return MoEModelOutputs(**output, logits=logits) # type: ignore[typeddict-item]
return MoEModelOutputs(**output, logits=logits)

def _forward(
self,
@@ -583,7 +587,7 @@ def _forward(
else:
output["router_logits"] = None

return MoEModelOutputs(**output) # type: ignore[typeddict-item]
return MoEModelOutputs(**output)

def build_embeddings(self, config: MoEConfig):
return nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)