Fix CrossEntropyLoss block to support multi-output models #28232

Open
Rishi-Dave wants to merge 3 commits into microsoft:main from Rishi-Dave:rishidave/fix/onnxblock-ce-loss-multi-output

@Rishi-Dave
Contributor

Summary

  • artifacts.generate_artifacts(..., loss=LossType.CrossEntropyLoss) no longer aborts with i < node_->OutputDefs().size() when the base model has multi-dimensional outputs.
  • The SoftmaxCrossEntropyLoss op produces two outputs (loss, log_prob); the second was being dropped by graph optimizers because it had no value_info entry, leaving the gradient builder to dereference a missing output def via O(1).

Motivation

Fixes #22465. Users hit a hard C++ assertion when training models like DistilBERT whose forward graph emits a multi-dimensional last-hidden-state tensor. The same pattern appears for any seq2seq / LM training setup that pipes a 3-D output into CrossEntropyLoss.

This is a Python-only change scoped to the onnxblock training-artifacts API; the core inference engine is unaffected.

Changes

  • orttraining/orttraining/python/training/onnxblock/loss/loss.py — after appending the SoftmaxCrossEntropyLoss node, register a value_info entry for log_prob_output_name so its output def survives shape inference and graph cleanup. Idempotent — guarded against duplicate entries.
  • orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py — new test_crossentropy_loss_multi_output_model builds a 3-D output toy model, calls generate_artifacts with LossType.CrossEntropyLoss, and asserts the saved training_model.onnx retains both outputs on the SCE node.
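
The duplicate guard described in the first bullet can be sketched roughly as follows. This is a minimal illustration, not the code in loss.py: `ensure_value_info` and `make_entry` are hypothetical names, and the graph is a plain stand-in object rather than a real onnx GraphProto (in the PR the entry is built with onnx.helper.make_tensor_value_info).

```python
from types import SimpleNamespace

FLOAT = 1  # numeric value of the ONNX TensorProto.FLOAT enum


def ensure_value_info(graph, name, make_entry):
    """Append a value_info entry for `name` unless one already exists.

    Returns True when an entry was added, False when the guard skipped it.
    `ensure_value_info` is a hypothetical helper, not part of onnxblock.
    """
    if any(vi.name == name for vi in graph.value_info):
        return False
    graph.value_info.append(make_entry(name))
    return True


def make_entry(name):
    # Stand-in for a ValueInfoProto; only the fields used here are modelled.
    return SimpleNamespace(name=name, elem_type=FLOAT)


# Stand-in graph with an empty value_info list (assumption for illustration).
graph = SimpleNamespace(value_info=[])

first = ensure_value_info(graph, "log_prob", make_entry)   # adds the entry
second = ensure_value_info(graph, "log_prob", make_entry)  # guard skips it
```

Calling the helper twice leaves a single entry behind, which is the property that matters if build() runs more than once on the same graph.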

Test Plan

  • New test exercises the previously-failing path:
    python -m pytest orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py::test_crossentropy_loss_multi_output_model -v
  • Existing CE coverage:
    python -m pytest orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py -k crossentropy -v
  • lintrunner clean on the diff.
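
The check the new test is described as making (both outputs retained on the SCE node) might look something like the sketch below. It assumes only that each node exposes `op_type` and `output`, as onnx NodeProto does; `nodes_missing_outputs` is a hypothetical helper and the node objects are stand-ins, not the actual test code.

```python
from types import SimpleNamespace


def nodes_missing_outputs(nodes, op_type, expected_outputs):
    """Return nodes of `op_type` that have fewer than `expected_outputs`
    non-empty output names. Hypothetical helper for illustration."""
    return [
        node
        for node in nodes
        if node.op_type == op_type
        and len([name for name in node.output if name]) < expected_outputs
    ]


# Stand-in nodes: one healthy SCE node and one whose second output was dropped.
healthy = SimpleNamespace(op_type="SoftmaxCrossEntropyLoss", output=["loss", "log_prob"])
pruned = SimpleNamespace(op_type="SoftmaxCrossEntropyLoss", output=["loss"])
other = SimpleNamespace(op_type="MatMul", output=["hidden"])

ok = nodes_missing_outputs([other, healthy], "SoftmaxCrossEntropyLoss", 2)
bad = nodes_missing_outputs([other, pruned], "SoftmaxCrossEntropyLoss", 2)
```

With a real model, the same function could be applied to the `graph.node` list of the loaded training_model.onnx, asserting that the returned list is empty.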

Fixes #22465

…put models

CrossEntropyLoss.build() created a SoftmaxCrossEntropyLoss node with two
outputs (loss, log_prob) but never registered log_prob in model.graph.value_info.
Graph optimizers then dropped the output def, causing the gradient builder to
hit a C++ assertion (i < node_->OutputDefs().size()) via O(1) when generating
training artifacts for models with multi-dimensional outputs (e.g. seq2seq).

Fix: after appending the node, add a value_info entry for log_prob_output_name
using the same elem_type as the input scores tensor. A guard prevents duplicate
entries if build() is called more than once. This keeps the output def alive
through graph cleanup without changing the user-visible API (the block still
returns only loss_node_output_name).

Fixes microsoft#22465
Contributor

Copilot AI left a comment

Pull request overview

Fixes onnxblock training-artifacts generation for LossType.CrossEntropyLoss when the base model output is multi-dimensional by ensuring the SoftmaxCrossEntropyLoss node’s second output (log_prob) is preserved through optimization/shape-inference.

Changes:

  • Add a value_info entry for SoftmaxCrossEntropyLoss’s log_prob output to prevent it from being dropped during graph optimization.
  • Add a regression test that exports a toy seq2seq-style (3-D output) model and verifies generate_artifacts succeeds and the saved training model retains both SCE outputs.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File | Description
orttraining/orttraining/python/training/onnxblock/loss/loss.py | Registers log_prob in value_info after adding SoftmaxCrossEntropyLoss to avoid optimizer pruning breaking gradient building.
orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py | Adds regression test covering multi-dimensional model output + CrossEntropyLoss artifact generation.

Comment on lines +122 to +126
    scores_info = _graph_utils.get_output_from_output_name(self.base, scores_input_name)
    scores_elem_type = scores_info.type.tensor_type.elem_type
    if not any(vi.name == log_prob_output_name for vi in self.base.graph.value_info):
        self.base.graph.value_info.append(
            onnx.helper.make_tensor_value_info(log_prob_output_name, scores_elem_type, None)

Rishi-Dave added 2 commits May 3, 2026 12:45
get_output_from_output_name only searches graph.output, causing a
LookupError when scores_input_name is an intermediate tensor (not yet
a graph output). Add get_value_info_for_name to _graph_utils.py that
searches graph.output -> graph.input -> graph.value_info in order, and
use it in CrossEntropyLoss.build to resolve scores_elem_type, restoring
support for intermediate-tensor inputs.
…ntropyLoss

Line 92 of loss.py used get_output_from_output_name to derive the
labels_input shape from scores_input_name. Like the scores path fixed
previously, this raises LookupError whenever scores_input_name is an
intermediate tensor rather than a declared graph output. Switch to
get_value_info_for_name so both call sites handle all tensor sources
consistently.

Also trim get_value_info_for_name's docstring to a one-liner and drop
the single quotes in its LookupError message to match sibling helpers.
@Rishi-Dave
Contributor Author

Thanks for the catch. Pushed 122e5bbcd5 (with prerequisite 1ee5f8248a) to address the regression around the unconditional get_output_from_output_name call.

Changes:

  • Added _graph_utils.get_value_info_for_name(model, name) that resolves a ValueInfoProto by searching graph.output -> graph.input -> graph.value_info in order, raising LookupError only if none match.
  • CrossEntropyLoss.build now uses the new helper at both call sites: the scores_elem_type lookup for the new log_prob value_info entry, and the earlier labels_input = copy.deepcopy(...) path. This restores the documented "output so far" behavior where loss_input_name may be an intermediate tensor with type info in value_info.
  • The existing multi-output regression test still covers the original reported failure; both call sites now share a single resolution strategy.
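
The lookup order described above can be sketched as follows. This is an illustration only, not the helper added in the PR: `resolve_value_info` is a hypothetical name, the error message is invented, and the model is a stand-in object that mimics only the `graph.output`, `graph.input`, and `graph.value_info` fields of an onnx ModelProto.

```python
from types import SimpleNamespace


def resolve_value_info(model, name):
    """Find the entry called `name`, searching graph.output, then graph.input,
    then graph.value_info. Raises LookupError if no collection contains it.
    Hypothetical sketch; not the PR's get_value_info_for_name."""
    graph = model.graph
    for collection in (graph.output, graph.input, graph.value_info):
        for value_info in collection:
            if value_info.name == name:
                return value_info
    raise LookupError(f"{name} not found in graph outputs, inputs, or value_info")


def vi(name):
    # Stand-in for a ValueInfoProto; only `name` is needed for the lookup.
    return SimpleNamespace(name=name)


# Stand-in model: "hidden" is an intermediate tensor known only via value_info.
model = SimpleNamespace(
    graph=SimpleNamespace(
        output=[vi("logits")],
        input=[vi("input_ids")],
        value_info=[vi("hidden")],
    )
)

found = resolve_value_info(model, "hidden")
```

A lookup restricted to graph.output would fail for "hidden" here, which mirrors the intermediate-tensor case the follow-up commits address.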

Development

Successfully merging this pull request may close these issues.

[Training] Error building gradient graph for bert models for on-device training
