# [Contribution] SolarOpenForCausalLM Support #65

**Status:** Open. lifelongeeek wants to merge 10 commits into `aws-neuron:main` from `lifelongeeek:feat/solar-open-support` (+2,534 −0).
## Commits

- `c72300b` feat: add SolarOpenForCausalLM model implementation (circle-jin)
- `1e56783` feat: add SolarOpen generation demos and accuracy tests (circle-jin)
- `3ffab2b` docs: add SolarOpen implementation and testing documentation (circle-jin)
- `8dcaa42` fix: remove undefined tensor_capture_hook from model_inputs in hf_ada… (lifelongeeek)
- `d5849b3` refactor: remove solar_open from src/models and root, moved to contrib (circle-jin)
- `a276f2c` feat: add Solar Open 100B MoE contrib model with tests, examples, and… (circle-jin)
- `fda7e0f` fix: resolve integration test failures for Solar Open MoE (circle-jin)
- `bfcb2b3` docs: correct 'Solar Open not in transformers' comments (circle-jin)
- `cd93feb` refactor(solar_open): migrate to transformers 5.0.0 SolarOpenForCausalLM (circle-jin)
- `25c6d95` refactor(solar_open): scope PR to contrib/models/solar_open only (circle-jin)
# Contrib Model: Solar Open 100B MoE

NeuronX Distributed Inference implementation of [upstage/Solar-Open-100B](https://huggingface.co/upstage/Solar-Open-100B), a 100B-parameter Mixture-of-Experts language model.

## Model Information

- **HuggingFace ID:** `upstage/Solar-Open-100B`
- **Model Type:** Decoder-only MoE transformer
- **Architecture:** 64 routed experts + 1 shared expert per layer, top-2 routing
- **Parameters:** ~100B total, ~22B active per token
- **License:** Check HuggingFace model card

> **Note:** Solar Open support (`SolarOpenForCausalLM`) requires `transformers` >= 5.0.0. The model config and weights can also be loaded directly from the HuggingFace checkpoint using the custom loader (`load_solar_open_config`).
## Architecture Details

Solar Open shares the same MoE routing architecture as GLM-4.5 MoE, with the following key differences:

| Property | Solar Open | GLM-4.5 MoE |
|----------|-----------|-------------|
| `partial_rotary_factor` | 1.0 (full RoPE) | < 1.0 (partial RoPE) |
| `attention_bias` | False | True |
| `use_qk_norm` | False | True |
| `first_k_dense_replace` | **0** (all layers MoE) | > 0 (some dense layers) |
| `rope_scaling` | None or `yarn` | None |
| In `transformers` | ✅ Yes (>= 5.0.0) | ✅ Yes |

### MoE Configuration (100B model)

- `n_routed_experts`: 64
- `n_shared_experts`: 1
- `num_experts_per_tok`: 2 (top-2 routing)
- `n_group`: 8, `topk_group`: 2
- `norm_topk_prob`: True
- `routed_scaling_factor`: 1.0
- Router: sigmoid + group-limited routing + `e_score_correction_bias`
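The routing configuration above (sigmoid scoring, a selection-only correction bias, group-limited expert selection, then top-2 with normalized weights) can be sketched roughly as follows. This is an illustrative reimplementation for reading purposes only, not the NxDI kernel; `route_tokens` and its signature are hypothetical:

```python
import torch


def route_tokens(logits, e_score_correction_bias, n_group=8, topk_group=2,
                 top_k=2, norm_topk_prob=True, routed_scaling_factor=1.0):
    """logits: [tokens, n_experts]; returns (weights, expert_ids), each [tokens, top_k]."""
    n_tokens, n_experts = logits.shape
    group_size = n_experts // n_group

    scores = torch.sigmoid(logits)             # sigmoid router, not softmax
    biased = scores + e_score_correction_bias  # bias affects selection only

    # Group-limited routing: score each group by its top-2 experts,
    # then keep only the best `topk_group` of `n_group` groups.
    group_scores = (
        biased.view(n_tokens, n_group, group_size)
        .topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(n_tokens, n_group, group_size)
        .reshape(n_tokens, n_experts)
    )

    # Top-2 experts among the surviving groups; weights use the unbiased scores.
    expert_ids = (biased * expert_mask).topk(top_k, dim=-1).indices
    weights = scores.gather(1, expert_ids)
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights * routed_scaling_factor, expert_ids


weights, ids = route_tokens(torch.randn(4, 64), torch.zeros(64))
```

With `norm_topk_prob=True` and `routed_scaling_factor=1.0`, the two expert weights for each token sum to 1.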
### Expert Parallelism Limitation

> ⚠️ **EP (Expert Parallelism) is currently limited to `moe_ep_degree=1`** due to a known issue with MoE EP group initialization when `n_group > 1`. Use TP-only parallelism for now.

Once that issue is resolved, the recommended production config is `tp_degree=32, moe_tp_degree=4, moe_ep_degree=8` (requires trn2.48xlarge or equivalent).
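The degrees used throughout this README obey a simple invariant: the MoE sharding (`moe_tp_degree * moe_ep_degree`) must multiply out to `tp_degree` (4 * 8 = 32 in the production config, 2 * 1 = 2 in the demo defaults). A quick sanity check, sketched here as a hypothetical helper rather than an NxDI API, makes the current EP restriction explicit:

```python
def check_moe_parallelism(tp_degree: int, moe_tp_degree: int,
                          moe_ep_degree: int, n_group: int = 8) -> None:
    """Hypothetical config sanity check; not part of NxDI."""
    # MoE layers are sharded across moe_tp_degree * moe_ep_degree cores,
    # which must cover exactly the tp_degree cores used by the rest of the model.
    if moe_tp_degree * moe_ep_degree != tp_degree:
        raise ValueError(
            f"moe_tp_degree * moe_ep_degree = {moe_tp_degree * moe_ep_degree} "
            f"!= tp_degree = {tp_degree}"
        )
    # Known issue: MoE EP group initialization fails when n_group > 1,
    # so keep EP disabled (moe_ep_degree=1) for now.
    if n_group > 1 and moe_ep_degree != 1:
        raise ValueError(
            "moe_ep_degree must be 1 while n_group > 1 (known EP issue)"
        )


check_moe_parallelism(tp_degree=32, moe_tp_degree=32, moe_ep_degree=1)  # TP-only: OK
```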
## Hardware Requirements

| Configuration | Instance |
|--------------|----------|
| Development / testing | trn1.32xlarge (32 NeuronCores) |
| Production (100B, seq_len=65536) | trn2.48xlarge (128 NeuronCores) |

## Usage
```python
import sys
sys.path.insert(0, "contrib/models/solar_open/src")

import torch
from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
from solar_open.modeling_solar_open import (
    SolarOpenInferenceConfig,
    NeuronSolarOpenForCausalLM,
    load_solar_open_config,
)

model_path = "/path/to/upstage/Solar-Open-100B"
traced_model_path = "/path/to/traced_model"

neuron_config = MoENeuronConfig(
    tp_degree=32,
    moe_tp_degree=4,
    moe_ep_degree=8,
    batch_size=4,
    seq_len=65536,
    torch_dtype=torch.bfloat16,
    on_device_sampling_config=OnDeviceSamplingConfig(
        do_sample=True, temperature=0.6, top_k=20, top_p=0.95
    ),
    fused_qkv=True,
    qkv_kernel_enabled=True,
    attn_kernel_enabled=True,
)

config = SolarOpenInferenceConfig(
    neuron_config,
    load_config=load_solar_open_config(model_path),
)

# Compile
model = NeuronSolarOpenForCausalLM(model_path, config)
model.compile(traced_model_path)

# Load and run
model = NeuronSolarOpenForCausalLM(traced_model_path)
model.load(traced_model_path)
```
See `examples/generation_solar_open_demo.py` for a full end-to-end example.

> **Review comment:** nit: unavailable file - `../../examples/generation_solar_open.py`
## Testing

### Unit Tests (CPU, no Neuron hardware required)

```bash
cd contrib/models/solar_open
source /path/to/neuronx_venv/bin/activate
python -m pytest test/unit/ -v
```

### Integration Tests (requires Neuron hardware)

```bash
cd contrib/models/solar_open
python -m pytest test/integration/ -v --capture=tee-sys
```

Integration tests compile a 2-layer tiny random model and verify:

1. **Smoke test** — model compiles and loads without error
2. **Output shape** — generated token IDs have correct shape
3. **Determinism** — same input produces same output across runs
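The determinism check (item 3) boils down to a pattern like the following. `assert_deterministic` is a hypothetical helper, shown here against a stand-in generate function rather than the compiled Neuron model:

```python
import torch


def assert_deterministic(generate_fn, input_ids, runs=3):
    """Run generate_fn repeatedly and require bit-identical token IDs."""
    outputs = [generate_fn(input_ids) for _ in range(runs)]
    for out in outputs[1:]:
        assert torch.equal(outputs[0], out), "generation is not deterministic"
    return outputs[0]


# Greedy decoding (do_sample=False, top_k=1) should always satisfy this.
# Stand-in for model.generate: append one deterministic token per call.
fake_generate = lambda ids: torch.cat([ids, ids[:, -1:] + 1], dim=1)
tokens = assert_deterministic(fake_generate, torch.tensor([[1, 2, 3]]))
```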
## Compatibility Matrix

| Instance | NxDI Version | Status |
|----------|-------------|--------|
| trn1.32xlarge | 2.20+ | ✅ Validated (unit tests) |
| trn2.48xlarge | 2.20+ | 🔧 Integration pending |
| Inf2 | Any | Not tested |

## Maintainer

Contributed by: gmkim (lifelongeeek)

**Last Updated:** 2026-03-06
## `contrib/models/solar_open/examples/generation_solar_open_demo.py` (+233 lines)

```python
"""
Solar Open MoE Generation Demo (contrib version).

This script demonstrates how to compile and run inference with the Solar Open MoE model
using neuronx-distributed-inference. Solar Open is available in transformers >= 5.0.0.

Usage:
    # Compile and generate (tiny random model):
    python generation_solar_open_demo.py

    # Skip compile (load from existing traced model):
    python generation_solar_open_demo.py --skip-compile

    # Production Solar Open 100B (trn2.48xlarge recommended):
    python generation_solar_open_demo.py \\
        --model-path /path/to/upstage/Solar-Open-100B \\
        --traced-model-path /path/to/Solar-Open-100B-traced \\
        --tp-degree 32 \\
        --seq-len 65536
"""

import argparse
import os
import shutil
import sys
from pathlib import Path

# Add contrib src to path so we can import solar_open directly
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

import torch
from transformers import AutoTokenizer, GenerationConfig

from neuronx_distributed_inference.models.config import (
    MoENeuronConfig,
    OnDeviceSamplingConfig,
)
from solar_open.modeling_solar_open import (
    SolarOpenInferenceConfig,
    NeuronSolarOpenForCausalLM,
    load_solar_open_config,
)
from neuronx_distributed_inference.utils.hf_adapter import (
    HuggingFaceGenerationAdapter,
)

# Default paths — override via CLI args
MODEL_PATH = "solar_open_tiny_random"
TRACED_MODEL_PATH = "solar_open_tiny_random_traced"

torch.manual_seed(0)

DTYPE = torch.bfloat16


def get_neuron_config(tp_degree: int = 2, seq_len: int = 64) -> MoENeuronConfig:
    """Create MoENeuronConfig for Solar Open.

    Defaults are sized for a 2-core tiny random model.
    For Solar Open 100B on trn2.48xlarge use tp_degree=32, seq_len=65536.
    """
    return MoENeuronConfig(
        tp_degree=tp_degree,
        moe_tp_degree=min(tp_degree, 4),
        moe_ep_degree=max(1, tp_degree // 4),
        batch_size=1,
        ctx_batch_size=1,
        tkg_batch_size=1,
        seq_len=seq_len,
        max_context_length=seq_len - 16,
        torch_dtype=DTYPE,
        on_device_sampling_config=OnDeviceSamplingConfig(
            do_sample=False,
            top_k=1,
        ),
        enable_bucketing=False,
        flash_decoding_enabled=False,
        fused_qkv=True,
        sequence_parallel_enabled=False,
        qkv_kernel_enabled=(tp_degree >= 8),
        attn_kernel_enabled=(tp_degree >= 8),
    )


def generate(
    model_path: str,
    traced_model_path: str,
    skip_compile: bool = False,
    tp_degree: int = 2,
    seq_len: int = 64,
):
    """Compile (if needed) and run Solar Open MoE inference."""
    if not skip_compile:
        print("=" * 60)
        print("Compiling Solar Open MoE model...")
        print("=" * 60)

        neuron_config = get_neuron_config(tp_degree=tp_degree, seq_len=seq_len)
        config = SolarOpenInferenceConfig(
            neuron_config,
            load_config=load_solar_open_config(model_path),
        )

        print(
            f"  Model config: hidden_size={config.hidden_size}, "
            f"n_routed_experts={config.n_routed_experts}, "
            f"n_shared_experts={config.n_shared_experts}, "
            f"num_experts_per_tok={config.num_experts_per_tok}"
        )

        model = NeuronSolarOpenForCausalLM(model_path, config)
        model.compile(traced_model_path)

        # Copy model weights and generation config to traced path so load() finds them.
        for fname in ("model.safetensors", "generation_config.json"):
            src = os.path.join(model_path, fname)
            dst = os.path.join(traced_model_path, fname)
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy2(src, dst)
                print(f"  Copied {fname} to {traced_model_path}")

        # Save tokenizer alongside the traced model for convenience.
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            tokenizer.save_pretrained(traced_model_path)
        except Exception as e:
            print(f"  Warning: could not save tokenizer: {e}")

        print(f"  Model compiled and saved to {traced_model_path}")

    # Load compiled model
    print("\n" + "=" * 60)
    print("Loading compiled Solar Open MoE model...")
    print("=" * 60)
    model = NeuronSolarOpenForCausalLM(traced_model_path)
    model.load(traced_model_path)

    # Load tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(traced_model_path)
    except Exception:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception:
            tokenizer = None

    # Generate
    print("\n" + "=" * 60)
    print("Generating outputs...")
    print("=" * 60)

    prompt = "What is the capital of France?"

    if tokenizer is not None:
        inputs = tokenizer([prompt], return_tensors="pt", padding=True)
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask
        print(f"  Prompt: {prompt!r}")
    else:
        input_ids = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        print(f"  Using dummy input_ids: {input_ids}")

    try:
        generation_config = GenerationConfig.from_pretrained(model_path)
    except Exception:
        generation_config = GenerationConfig(
            max_new_tokens=10,
            do_sample=False,
            top_k=1,
        )

    generation_model = HuggingFaceGenerationAdapter(model)
    outputs = generation_model.generate(
        input_ids,
        generation_config=generation_config,
        attention_mask=attention_mask,
        max_length=model.config.neuron_config.max_length,
    )

    print(f"  Output token ids: {outputs}")

    if tokenizer is not None:
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        print("  Generated text:")
        for i, text in enumerate(decoded):
            print(f"    [{i}]: {text}")

    return outputs


def main():
    parser = argparse.ArgumentParser(description="Solar Open MoE generation demo")
    parser.add_argument(
        "--model-path",
        default=MODEL_PATH,
        help="Path to HF model (local or HuggingFace Hub ID)",
    )
    parser.add_argument(
        "--traced-model-path",
        default=TRACED_MODEL_PATH,
        help="Path to save/load the compiled Neuron model",
    )
    parser.add_argument(
        "--skip-compile",
        action="store_true",
        help="Skip compilation; load an existing traced model",
    )
    parser.add_argument(
        "--tp-degree",
        type=int,
        default=2,
        help="Tensor parallelism degree (use 32 for 100B on trn2.48xlarge)",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=64,
        help="Maximum sequence length (use 65536 for 100B)",
    )
    args = parser.parse_args()

    generate(
        model_path=args.model_path,
        traced_model_path=args.traced_model_path,
        skip_compile=args.skip_compile,
        tp_degree=args.tp_degree,
        seq_len=args.seq_len,
    )


if __name__ == "__main__":
    main()
```
## `solar_open` package `__init__.py` (+24 lines)

```python
# Solar Open contrib package
from .modeling_solar_open import (
    NeuronSolarOpenForCausalLM,
    NeuronSolarOpenModel,
    NeuronSolarOpenDecoderLayer,
    NeuronSolarOpenAttention,
    NeuronSolarOpenRouter,
    SolarOpenInferenceConfig,
    SolarOpenYarnRotaryEmbedding,
    load_solar_open_config,
    convert_solar_open_hf_to_neuron_state_dict,
)

__all__ = [
    "NeuronSolarOpenForCausalLM",
    "NeuronSolarOpenModel",
    "NeuronSolarOpenDecoderLayer",
    "NeuronSolarOpenAttention",
    "NeuronSolarOpenRouter",
    "SolarOpenInferenceConfig",
    "SolarOpenYarnRotaryEmbedding",
    "load_solar_open_config",
    "convert_solar_open_hf_to_neuron_state_dict",
]
```
> **Review comment:** nit: As you mentioned in the generation code, it is available in `transformers` (>= 5.0.0).