Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mostlyai/engine/_tabular/argn.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def forward(self, x) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
mask = None
for sub_col in self.cardinalities:
xs = torch.as_tensor(x[sub_col], device=self.device)
xs = torch.nested.to_padded_tensor(xs, padding=-1)
if xs.is_nested:
xs = torch.nested.to_padded_tensor(xs, padding=-1)
mask = (xs != -1).squeeze(-1)
xs = torch.where(xs == -1, torch.tensor(0), xs)
xs = self.get(sub_col)(xs)
Expand Down
46 changes: 35 additions & 11 deletions mostlyai/engine/_tabular/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,20 @@ class BatchCollator:
For sequence data, it will sample subsequences with lengths up to max_sequence_window.
"""

def __init__(self, is_sequential: bool, max_sequence_window: int | None, device: torch.device):
def __init__(
    self,
    is_sequential: bool,
    max_sequence_window: int | None,
    device: torch.device,
    *,
    use_nested_ctxseq: bool = True,
):
    """Configure the batch collator.

    Args:
        is_sequential: whether the target data is sequential (per-row sequences).
        max_sequence_window: upper bound for sampled subsequence lengths, or None.
        device: torch device that collated tensors are placed on.
        use_nested_ctxseq: keyword-only; when True, CTXSEQ columns are collated as
            NestedTensors, otherwise as -1-padded dense tensors (DP-compatible path).
    """
    self.is_sequential = is_sequential
    self.max_sequence_window = max_sequence_window
    self.device = device
    # Opacus per-sample gradients do not support NestedTensor on CPU/CUDA; use padded
    # dense tensors for CTXSEQ when training with DP (see test_tabular_sequential DP path).
    self.use_nested_ctxseq = use_nested_ctxseq

def __call__(self, batch: list[dict]) -> dict[str, torch.Tensor]:
batch = pd.DataFrame(batch)
Expand Down Expand Up @@ -177,15 +187,26 @@ def _convert_to_tensors(self, batch: pd.DataFrame) -> dict[str, torch.Tensor]:
dim=-1,
)
elif column.startswith(CTXSEQ):
# construct row tensors and convert the list to nested column tensor
tensors[column] = torch.unsqueeze(
torch.nested.as_nested_tensor(
[torch.tensor(row, dtype=torch.int64, device=self.device) for row in batch[column]],
dtype=torch.int64,
device=self.device,
),
dim=-1,
)
if self.use_nested_ctxseq:
# construct row tensors and convert the list to nested column tensor
tensors[column] = torch.unsqueeze(
torch.nested.as_nested_tensor(
[torch.tensor(row, dtype=torch.int64, device=self.device) for row in batch[column]],
dtype=torch.int64,
device=self.device,
),
dim=-1,
)
else:
# padded batch (variable-length rows); -1 marks padding (matches SequentialContextEmbedders)
tensors[column] = torch.unsqueeze(
torch.tensor(
np.array(list(zip_longest(*batch[column], fillvalue=-1))).T,
dtype=torch.int64,
device=self.device,
),
dim=-1,
)
return tensors

@staticmethod
Expand Down Expand Up @@ -544,7 +565,10 @@ def train(

# and see if it's possible to make it compatible with DP
batch_collator = BatchCollator(
is_sequential=is_sequential, max_sequence_window=max_sequence_window, device=device
is_sequential=is_sequential,
max_sequence_window=max_sequence_window,
device=device,
use_nested_ctxseq=not with_dp,
)
disable_progress_bar()
trn_dataset = load_dataset("parquet", data_files=[str(p) for p in workspace.encoded_data_trn.fetch_all()])[
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
"accelerate>=1.5.0",
"peft>=0.18.2", # transformers 5.7+ checks min PEFT in model.add_adapter (integrations/peft.py)
"huggingface-hub[hf-xet]>=0.30.2",
"opacus>=1.5.4",
"opacus>=1.6.0",
"xgrammar>=0.1.32,<1.0.0", # aligned with vllm 0.20
"json-repair>=0.47.0",
"torch>=2.11.0,<2.12.0",
Expand Down
10 changes: 5 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.