From 6bd99b687abe0b4a622af844c581bf0b36c663a5 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 29 May 2026 21:31:51 -0700 Subject: [PATCH 1/2] Support latest SGLang integration Signed-off-by: Yubo Wang --- docker/sglang/v0.5.12/Dockerfile | 27 + patches/sglang/v0.5.12/sglang.patch | 998 ++++++++++++++++++++++++ tests/test_mooncake_force_delete.py | 13 + tests/test_sglang_engine_integration.py | 11 +- torchspec/config/mooncake_config.py | 13 +- 5 files changed, 1055 insertions(+), 7 deletions(-) create mode 100644 docker/sglang/v0.5.12/Dockerfile create mode 100644 patches/sglang/v0.5.12/sglang.patch diff --git a/docker/sglang/v0.5.12/Dockerfile b/docker/sglang/v0.5.12/Dockerfile new file mode 100644 index 00000000..2450dfcd --- /dev/null +++ b/docker/sglang/v0.5.12/Dockerfile @@ -0,0 +1,27 @@ +ARG SGLANG_IMAGE=598726163780.dkr.ecr.us-west-2.amazonaws.com/tgl:v0.5.12-cu130-43a42c1bd9-dirty-warm-20260526 +FROM ${SGLANG_IMAGE} AS sglang + +WORKDIR /root/ + +COPY patches/sglang/v0.5.12/sglang.patch /sgl-workspace/ +RUN cd /sgl-workspace/sglang && \ + git apply /sgl-workspace/sglang.patch && \ + rm /sgl-workspace/sglang.patch && \ + cd python && pip install -e . + +COPY . /root/torchspec +RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" + +# TorchSpec's dependency pulls the generic CUDA 12 Mooncake wheel. Replace it +# with the CUDA 13 wheel that matches this image family. +RUN pip uninstall -y mooncake-transfer-engine mooncake-transfer-engine-cuda13 || true && \ + pip install --no-cache-dir --no-deps mooncake-transfer-engine-cuda13==0.3.11.post1 + +RUN chmod 755 /usr/local/lib/python3.12/dist-packages/mooncake/mooncake_master || true +RUN if [ -f /usr/local/lib/python3.12/dist-packages/mooncake/cli.py ]; then \ + sed -i 's/os.chmod(bin_path, 0o755)/pass/' /usr/local/lib/python3.12/dist-packages/mooncake/cli.py; \ + fi + +RUN python3 -c "from sglang.srt.server_args import ServerArgs; from sglang.srt.speculative.spec_training_info import SpecTrainingInfo; assert 'enable_spec_training_mooncake' in ServerArgs.__dataclass_fields__; assert 'enable_aux_hidden_states' in ServerArgs.__dataclass_fields__" + +WORKDIR /root/torchspec diff --git a/patches/sglang/v0.5.12/sglang.patch b/patches/sglang/v0.5.12/sglang.patch new file mode 100644 index 00000000..17342959 --- /dev/null +++ b/patches/sglang/v0.5.12/sglang.patch @@ -0,0 +1,998 @@ +torchspec sglang patch (base: 0.5.12+mm.43a42c1bd9.dirty image source) +--- + python/sglang/srt/entrypoints/engine.py | 8 + + python/sglang/srt/entrypoints/http_server.py | 17 +++ + python/sglang/srt/layers/logits_processor.py | 54 +++++++ + python/sglang/srt/managers/detokenizer_manager.py | 3 + + python/sglang/srt/managers/io_struct.py | 48 ++++++ + python/sglang/srt/managers/schedule_batch.py | 54 ++++++- + python/sglang/srt/managers/scheduler.py | 57 ++++++- + .../managers/scheduler_output_processor_mixin.py | 169 ++++++++++++++++++--- + python/sglang/srt/managers/tokenizer_manager.py | 14 ++ + .../srt/model_executor/forward_batch_info.py | 4 + + python/sglang/srt/model_executor/model_runner.py | 15 ++ + python/sglang/srt/models/qwen3_next.py | 8 + + python/sglang/srt/models/qwen3_next_mtp.py | 3 + + python/sglang/srt/server_args.py | 30 ++++ + .../sglang/srt/speculative/spec_training_info.py | 50 ++++++ + 15 files changed, 514 insertions(+), 20 deletions(-) + +diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py +index f96445c..d80d6b9 100644 +--- a/python/sglang/srt/entrypoints/engine.py ++++ b/python/sglang/srt/entrypoints/engine.py +@@ -346,6 +346,8 @@ class Engine(EngineScoreMixin, EngineBase): + rid: Optional[Union[List[str], str]] = None, + session_params: Optional[Dict] = None, + priority: Optional[int] = None, ++ spec_training_data_id: Optional[Union[List[str], str]] = None, ++ packed_loss_mask: Optional[Union[List[str], str]] = None, + ) -> Union[Dict, Iterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. +@@ -381,6 +383,8 @@ class Engine(EngineScoreMixin, EngineBase): + rid=rid, + session_params=session_params, + priority=priority, ++ spec_training_data_id=spec_training_data_id, ++ packed_loss_mask=packed_loss_mask, + ) + generator = self.tokenizer_manager.generate_request(obj, None) + +@@ -438,6 +442,8 @@ class Engine(EngineScoreMixin, EngineBase): + rid: Optional[Union[List[str], str]] = None, + session_params: Optional[Dict] = None, + priority: Optional[int] = None, ++ spec_training_data_id: Optional[Union[List[str], str]] = None, ++ packed_loss_mask: Optional[Union[List[str], str]] = None, + ) -> Union[Dict, AsyncIterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. +@@ -473,6 +479,8 @@ class Engine(EngineScoreMixin, EngineBase): + rid=rid, + session_params=session_params, + priority=priority, ++ spec_training_data_id=spec_training_data_id, ++ packed_loss_mask=packed_loss_mask, + ) + generator = self.tokenizer_manager.generate_request(obj, None) + +diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py +index 80081fc..d0b86e9 100644 +--- a/python/sglang/srt/entrypoints/http_server.py ++++ b/python/sglang/srt/entrypoints/http_server.py +@@ -734,6 +734,23 @@ async def generate_request(obj: GenerateReqInput, request: Request): + return _create_error_response(e) + + ++@app.api_route("/generate_for_spec_training", methods=["POST", "PUT"]) ++async def generate_for_spec_training(obj: GenerateReqInput, request: Request): ++ """Handle a speculative training data collection request. ++ ++ This endpoint reuses the generate flow but expects spec_training_data_id ++ and packed_loss_mask to be set. ++ """ ++ try: ++ ret = await _global_state.tokenizer_manager.generate_request( ++ obj, request ++ ).__anext__() ++ return ret ++ except ValueError as e: ++ logger.error(f"[http_server] Error: {e}") ++ return _create_error_response(e) ++ ++ + @app.api_route("/encode", methods=["POST", "PUT"]) + async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index e14895e..d9b1099 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -107,6 +107,10 @@ class LogitsProcessorOutput: + + mm_input_embeds: Optional[torch.Tensor] = None + ++ ## Part 6: Spec training - skip sampling and use these fake token ids ++ skip_sampling_next_token_ids: Optional[torch.Tensor] = None ++ last_hidden_states: Optional[torch.Tensor] = None ++ + + @dataclasses.dataclass + class LogitsMetadata: +@@ -151,6 +155,9 @@ class LogitsMetadata: + + mm_input_embeds: Optional[torch.Tensor] = None + ++ # For spec training ++ has_spec_training: bool = False ++ + @classmethod + def from_forward_batch(cls, forward_batch: ForwardBatch): + if ( +@@ -202,6 +209,7 @@ class LogitsMetadata: + global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu, + dp_padding_mode=DpPaddingMode.SUM_LEN, + mm_input_embeds=forward_batch.mm_input_embeds, ++ has_spec_training=forward_batch.has_spec_training, + ) + + def compute_dp_attention_metadata(self): +@@ -311,6 +319,11 @@ class LogitsProcessor(nn.Module): + if logits_metadata.forward_mode.is_dllm_extend(): + return self._get_dllm_logits(hidden_states, lm_head, logits_metadata) + ++ last_hidden_states = None ++ if logits_metadata.has_spec_training: ++ assert hidden_states_before_norm is None ++ last_hidden_states = hidden_states ++ + # Get the last hidden states and last logits for the next token prediction + ( + pruned_states, +@@ -338,6 +351,46 @@ class LogitsProcessor(nn.Module): + ) + del hidden_states + ++ # TODO(ywang): Support finegrained control over requests instead of ++ # forcing all requests to be spec training requests ++ ++ # For offline spec training (prefill-only, no decode), skip sampling ++ # and return fake EOS. In decode mode with EAGLE, the eagle_worker ++ # sets has_spec_training=False so this block is not triggered. ++ if ( ++ logits_metadata.has_spec_training ++ and logits_metadata.forward_mode.is_extend() ++ ): ++ if logits_metadata.extend_seq_lens is not None: ++ num_seqs = len(logits_metadata.extend_seq_lens) ++ elif input_ids is not None: ++ num_seqs = input_ids.shape[0] ++ else: ++ num_seqs = 1 ++ eos_token_id = getattr(self.config, "eos_token_id", 0) ++ if isinstance(eos_token_id, list): ++ eos_token_id = eos_token_id[0] ++ if input_ids is not None: ++ device = input_ids.device ++ elif isinstance(last_hidden_states, torch.Tensor): ++ device = last_hidden_states.device ++ elif isinstance(last_hidden_states, tuple): ++ device = last_hidden_states[0].device ++ else: ++ device = lm_head.weight.device ++ fake_next_token_ids = torch.full( ++ (num_seqs,), ++ eos_token_id, ++ dtype=torch.long, ++ device=device, ++ ) ++ return LogitsProcessorOutput( ++ next_token_logits=None, ++ hidden_states=hidden_states_to_store, ++ last_hidden_states=last_hidden_states, ++ skip_sampling_next_token_ids=fake_next_token_ids, ++ ) ++ + if not logits_metadata.extend_return_logprob: + # Compute logits for both input and sampled tokens. + logits = self._get_logits(pruned_states, lm_head, logits_metadata) +@@ -416,6 +469,7 @@ class LogitsProcessor(nn.Module): + logits_metadata.forward_mode.is_decode_or_idle() + or logits_metadata.forward_mode.is_target_verify() + or logits_metadata.forward_mode.is_draft_extend_v2() ++ or logits_metadata.has_spec_training + ): + pruned_states = hidden_states + pruned_states_before_norm = hidden_states_before_norm +diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py +index a4547bf..a7ed053 100644 +--- a/python/sglang/srt/managers/detokenizer_manager.py ++++ b/python/sglang/srt/managers/detokenizer_manager.py +@@ -388,6 +388,9 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + load=recv_obj.load, + dp_ranks=recv_obj.dp_ranks, + time_stats=recv_obj.time_stats, ++ spec_training_data_ids=recv_obj.spec_training_data_ids, ++ packed_loss_masks=recv_obj.packed_loss_masks, ++ spec_training_mooncake_store_keys=recv_obj.spec_training_mooncake_store_keys, + ) + + def handle_freeze_gc_req(self, recv_req: FreezeGCReq): +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index 293335f..282803a 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -260,6 +260,13 @@ class GenerateReqInput(BaseReq): + # Batch-level: List[List[int]] (one per request). After __getitem__: List[int]. + multi_item_delimiter_indices: Optional[Union[List[List[int]], List[int]]] = None + ++ # Speculative training fields ++ spec_training_data_id: Optional[Union[List[str], str]] = None ++ packed_loss_mask: Optional[Union[List[str], str]] = None ++ ++ def is_spec_training_request(self) -> bool: ++ return self.spec_training_data_id is not None ++ + def contains_mm_input(self) -> bool: + return ( + has_valid_data(self.image_data) +@@ -405,6 +412,7 @@ class GenerateReqInput(BaseReq): + self._normalize_logprob_params(num) + self._normalize_custom_logit_processor(num) + self._normalize_bootstrap_params(num) ++ self._normalize_spec_training_params(num) + + def _expand_inputs(self, num): + """Expand the main inputs (text, input_ids, input_embeds) for parallel sampling.""" +@@ -607,6 +615,24 @@ class GenerateReqInput(BaseReq): + elif isinstance(self.bootstrap_pair_key, list): + self.bootstrap_pair_key = self.bootstrap_pair_key * self.parallel_sample_num + ++ def _normalize_spec_training_params(self, num): ++ """Normalize speculative training parameters for batch processing.""" ++ if self.spec_training_data_id is None: ++ self.spec_training_data_id = [None] * num ++ elif not isinstance(self.spec_training_data_id, list): ++ self.spec_training_data_id = [self.spec_training_data_id] * num ++ elif isinstance(self.spec_training_data_id, list): ++ self.spec_training_data_id = ( ++ self.spec_training_data_id * self.parallel_sample_num ++ ) ++ ++ if self.packed_loss_mask is None: ++ self.packed_loss_mask = [None] * num ++ elif not isinstance(self.packed_loss_mask, list): ++ self.packed_loss_mask = [self.packed_loss_mask] * num ++ elif isinstance(self.packed_loss_mask, list): ++ self.packed_loss_mask = self.packed_loss_mask * self.parallel_sample_num ++ + def _validate_session_params(self): + """Validate that session parameters are properly formatted.""" + if self.session_params is not None: +@@ -697,6 +723,14 @@ class GenerateReqInput(BaseReq): + external_trace_header=self.external_trace_header, + http_worker_ipc=self.http_worker_ipc, + received_time=self.received_time, ++ spec_training_data_id=( ++ self.spec_training_data_id[i] ++ if self.spec_training_data_id is not None ++ else None ++ ), ++ packed_loss_mask=( ++ self.packed_loss_mask[i] if self.packed_loss_mask is not None else None ++ ), + multi_item_delimiter_indices=( + self.multi_item_delimiter_indices[i] + if self.multi_item_delimiter_indices is not None +@@ -796,6 +830,10 @@ class TokenizedGenerateReqInput(BaseReq): + # Pre-computed delimiter indices for multi-item scoring + multi_item_delimiter_indices: Optional[List[int]] = None + ++ # Speculative training fields ++ spec_training_data_id: Optional[str] = None ++ packed_loss_mask: Optional[str] = None ++ + # For observability + time_stats: Optional[Union[APIServerReqTimeStats, DPControllerReqTimeStats]] = None + +@@ -1137,6 +1175,11 @@ class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): + # DP rank of the scheduler that processed each request + dp_ranks: Optional[List[int]] = None + ++ # Speculative training fields ++ spec_training_data_ids: Optional[List[str]] = None ++ packed_loss_masks: Optional[List[str]] = None ++ spec_training_mooncake_store_keys: Optional[List[List[str]]] = None ++ + # For observability + time_stats: Optional[List[SchedulerReqTimeStats]] = None + +@@ -1203,6 +1246,11 @@ class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): + # DP rank of the scheduler that processed each request + dp_ranks: Optional[List[int]] = None + ++ # Speculative training fields ++ spec_training_data_ids: Optional[List[str]] = None ++ packed_loss_masks: Optional[List[str]] = None ++ spec_training_mooncake_store_keys: Optional[List[List[str]]] = None ++ + # For observability + time_stats: Optional[List[SchedulerReqTimeStats]] = None + +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index ab8cbc7..e2502f0 100755 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -92,6 +92,7 @@ from sglang.srt.observability.req_time_stats import ( + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + from sglang.srt.sampling.sampling_params import SamplingParams + from sglang.srt.server_args import ServerArgs, get_global_server_args ++from sglang.srt.speculative.spec_training_info import SpecTrainingInfo + from sglang.srt.utils import flatten_nested_list + from sglang.srt.utils.cuda_ipc_transport_utils import CudaIpcTensorTransportProxy + +@@ -620,6 +621,8 @@ class Req(ReqDllmMixin): + ] = None, + return_pooled_hidden_states: bool = False, + multi_item_delimiter_indices: Optional[List[int]] = None, ++ spec_training_data_id: Optional[str] = None, ++ packed_loss_mask: Optional[str] = None, + ): + # Input and output info + self.rid = rid +@@ -662,6 +665,11 @@ class Req(ReqDllmMixin): + # For multi-http worker + self.http_worker_ipc = http_worker_ipc + ++ # Spec training fields ++ self.spec_training_data_id = spec_training_data_id ++ self.packed_loss_mask = packed_loss_mask ++ self.spec_training_mooncake_store_keys: List[str] = [] ++ + # Require reasoning for the request + self.require_reasoning = require_reasoning + +@@ -1523,6 +1531,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + # HiSparse + hisparse_coordinator: Optional[HiSparseCoordinator] = None + ++ # Spec Training ++ spec_training_info: Optional[SpecTrainingInfo] = None ++ + @classmethod + def init_new( + cls, +@@ -1542,6 +1553,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): + is_hybrid_swa = True + ++ spec_training_info = SpecTrainingInfo() ++ for req in reqs: ++ if req.spec_training_data_id is not None: ++ spec_training_info.add_request( ++ rid=req.rid, ++ data_id=req.spec_training_data_id, ++ packed_loss_mask=req.packed_loss_mask, ++ ) ++ + batch = cls( + reqs=reqs, + req_to_token_pool=req_to_token_pool, +@@ -1561,6 +1581,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + is_prefill_only=all(req.is_prefill_only for req in reqs), + chunked_req=chunked_req, + dllm_config=dllm_config, ++ spec_training_info=( ++ spec_training_info if not spec_training_info.is_empty() else None ++ ), + ) + return batch + +@@ -2481,6 +2504,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + has_been_filtered=has_been_filtered, + ) + ++ if self.spec_training_info is not None: ++ kept_rids = {req.rid for req in self.reqs} ++ rids_to_remove = [ ++ rid for rid in self.spec_training_info.data_ids if rid not in kept_rids ++ ] ++ for rid in rids_to_remove: ++ self.spec_training_info.remove_request(rid) ++ if self.spec_training_info.is_empty(): ++ self.spec_training_info = None ++ + def merge_batch(self, other: "ScheduleBatch"): + # In the regular scheduler path: + # 1) self is always prefill, whose seq_lens is not a future +@@ -2535,6 +2568,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + if self.spec_info: + self.spec_info.merge_batch(other.spec_info) + ++ if self.spec_training_info is not None or other.spec_training_info is not None: ++ if self.spec_training_info is None: ++ self.spec_training_info = SpecTrainingInfo() ++ if other.spec_training_info is not None: ++ self.spec_training_info.data_ids.update( ++ other.spec_training_info.data_ids ++ ) ++ self.spec_training_info.packed_loss_masks.update( ++ other.spec_training_info.packed_loss_masks ++ ) ++ self.spec_training_info.mooncake_store_keys.update( ++ other.spec_training_info.mooncake_store_keys ++ ) ++ + def get_model_worker_batch( + self, seq_lens_cpu_cache: Optional[torch.Tensor] = None + ) -> ModelWorkerBatch: +@@ -2595,7 +2642,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + hicache_consumer_index=self.hicache_consumer_index, + capture_hidden_mode=( + CaptureHiddenMode.FULL +- if self.return_hidden_states ++ if self.return_hidden_states or self.spec_training_info is not None + else ( + getattr( + self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL +@@ -2616,6 +2663,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + mamba_track_indices=self.mamba_track_indices, + mamba_track_mask=self.mamba_track_mask, + mamba_track_seqlens=self.mamba_track_seqlens, ++ has_spec_training=self.spec_training_info is not None, + ) + + def copy(self): +@@ -2647,6 +2695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + prefill_stats=self.prefill_stats, + fpm_start_time=self.fpm_start_time, + forward_iter=self.forward_iter, ++ spec_training_info=self.spec_training_info, + ) + + def maybe_evict_swa(self): +@@ -2851,3 +2900,6 @@ class ModelWorkerBatch: + mamba_track_indices: Optional[torch.Tensor] = None # shape: [b], int64 + mamba_track_mask: Optional[torch.Tensor] = None # shape: [b], bool + mamba_track_seqlens: Optional[torch.Tensor] = None # shape: [b], int64 ++ ++ # For spec training ++ has_spec_training: bool = False +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index a31a0e6..04456bb 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -433,6 +433,26 @@ class Scheduler( + # Init mamba backend + self.init_mamba_backend() + ++ # Start mooncake store init in background (overlaps with model loading) ++ self._mooncake_init_thread = None ++ self._mooncake_init_error = None ++ self.eagle_mooncake_store = None ++ if self.server_args.enable_spec_training_mooncake and self.attn_tp_rank == 0: ++ import threading ++ ++ mooncake_device = torch.device(f"cuda:{self.gpu_id}") ++ ++ def _init_mooncake(): ++ try: ++ self.init_eagle_mooncake_store(device=mooncake_device) ++ except Exception as e: ++ self._mooncake_init_error = e ++ ++ self._mooncake_init_thread = threading.Thread( ++ target=_init_mooncake, daemon=True ++ ) ++ self._mooncake_init_thread.start() ++ + # Launch a model worker and draft model worker if using speculative decoding + self.init_model_worker() + self.install_device_timer_on_runners() +@@ -497,6 +517,12 @@ class Scheduler( + # Init the grammar backend for constrained generation + self.grammar_manager = GrammarManager(self) + ++ # Wait for background mooncake store init to complete ++ if self._mooncake_init_thread is not None: ++ self._mooncake_init_thread.join() ++ if self._mooncake_init_error is not None: ++ raise self._mooncake_init_error ++ + self.is_initializing = False + + def init_zbal_on_npu(self): +@@ -1064,6 +1090,25 @@ class Scheduler( + self.forward_sleep_time = None + self._engine_paused = False + ++ def init_eagle_mooncake_store(self, device=None): ++ if self.server_args.enable_spec_training_mooncake: ++ try: ++ from torchspec.transfer.mooncake import ( ++ EagleMooncakeStore, ++ MooncakeConfig, ++ ) ++ ++ config = MooncakeConfig.from_env() ++ store = EagleMooncakeStore(config) ++ store.setup(device=device or self.device) ++ store.warmup_rdma() ++ self.eagle_mooncake_store = store ++ logger.info("EagleMooncakeStore initialized for spec training") ++ except ImportError: ++ logger.warning( ++ "torchspec.mooncake not found. Spec training mooncake store disabled." ++ ) ++ + def init_chunked_prefill(self): + self.chunked_prefill_size = self.server_args.chunked_prefill_size + uses_transformers_backend = ( +@@ -2051,6 +2096,8 @@ class Scheduler( + dllm_config=self.dllm_config, + time_stats=recv_req.time_stats, + multi_item_delimiter_indices=recv_req.multi_item_delimiter_indices, ++ spec_training_data_id=recv_req.spec_training_data_id, ++ packed_loss_mask=recv_req.packed_loss_mask, + ) + req.tokenizer = self.tokenizer + +@@ -3052,7 +3099,10 @@ class Scheduler( + self.future_map.store_to_map(future_indices, batch_result) + batch_result.copy_to_cpu( + return_logprob=batch.return_logprob, +- return_hidden_states=batch.return_hidden_states, ++ return_hidden_states=( ++ batch.return_hidden_states ++ and batch.spec_training_info is None ++ ), + ) + else: + batch_result.future_indices = future_indices +@@ -3160,7 +3210,10 @@ class Scheduler( + self.future_map.store_to_map(batch_result.future_indices, batch_result) + batch_result.copy_to_cpu( + return_logprob=self.cur_batch.return_logprob, +- return_hidden_states=self.cur_batch.return_hidden_states, ++ return_hidden_states=( ++ self.cur_batch.return_hidden_states ++ and self.cur_batch.spec_training_info is None ++ ), + ) + + # Release the closure and large GPU tensors that are no longer needed. +diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index ae6f732..2c12870 100644 +--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -283,21 +283,22 @@ class SchedulerOutputProcessorMixin: + ) + logprob_pt += num_input_logprobs + +- if ( +- req.return_hidden_states +- and logits_output.hidden_states is not None +- ): +- req.hidden_states.append( +- logits_output.hidden_states[ +- hidden_state_offset : ( +- hidden_state_offset := hidden_state_offset +- + len(req.origin_input_ids) +- ) +- ] +- .cpu() +- .clone() +- .tolist() ++ should_process_hidden_states = ( ++ logits_output.hidden_states is not None ++ and ( ++ req.return_hidden_states ++ or ( ++ req.spec_training_data_id is not None ++ and self.attn_tp_rank == 0 ++ ) + ) ++ ) ++ if should_process_hidden_states: ++ self._process_hidden_states_for_req( ++ req, batch, logits_output, hidden_state_offset, ++ copy_done_event=result.copy_done, ++ ) ++ hidden_state_offset += len(req.origin_input_ids) + + if req.grammar is not None: + # FIXME: this try-except block is for handling unexpected xgrammar issue. +@@ -993,6 +994,106 @@ class SchedulerOutputProcessorMixin: + if req.input_token_ids_logprobs_idx is None: + req.input_token_ids_logprobs_idx = [] + ++ def _process_hidden_states_for_req( ++ self: Scheduler, ++ req: Req, ++ batch: ScheduleBatch, ++ logits_output: LogitsProcessorOutput, ++ hidden_state_offset: int, ++ copy_done_event=None, ++ ): ++ """Process hidden states during prefill for spec training or return_hidden_states.""" ++ seq_len = len(req.origin_input_ids) ++ req_hidden_states = logits_output.hidden_states[ ++ hidden_state_offset : hidden_state_offset + seq_len ++ ] ++ ++ if ( ++ batch.spec_training_info is not None ++ and batch.spec_training_info.has_request(req.rid) ++ and self.eagle_mooncake_store is not None ++ ): ++ self._send_hidden_states_to_mooncake( ++ req, batch, req_hidden_states, logits_output, hidden_state_offset, ++ copy_done_event=copy_done_event, ++ ) ++ else: ++ if copy_done_event is not None: ++ copy_done_event.synchronize() ++ req.hidden_states.append( ++ req_hidden_states.cpu().clone().tolist() ++ ) ++ ++ def _put_spec_training_mooncake( ++ self: Scheduler, ++ key: str, ++ hidden_states: torch.Tensor, ++ input_ids: torch.Tensor, ++ last_hidden_states: Optional[torch.Tensor], ++ ): ++ import os ++ ++ if os.getenv("TORCHSPEC_USP_SHARDED_MOONCAKE") == "1": ++ max_seq_raw = os.environ.get("TORCHSPEC_USP_MAX_SEQ_LENGTH") ++ self.eagle_mooncake_store.put_usp_shards( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ target=None, ++ sp_size=int(os.environ["TORCHSPEC_USP_SP_SIZE"]), ++ sp_ring_size=int(os.environ.get("TORCHSPEC_USP_RING_SIZE", "1")), ++ ttt_length=int(os.environ.get("TORCHSPEC_USP_TTT_LENGTH", "1")), ++ max_seq_length=int(max_seq_raw) if max_seq_raw else None, ++ ) ++ return ++ ++ self.eagle_mooncake_store.put( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ ) ++ ++ def _send_hidden_states_to_mooncake( ++ self: Scheduler, ++ req: Req, ++ batch: ScheduleBatch, ++ hidden_states: torch.Tensor, ++ logits_output: LogitsProcessorOutput, ++ hidden_state_offset: int, ++ copy_done_event=None, ++ ): ++ import uuid ++ ++ data_id = batch.spec_training_info.data_ids.get(req.rid, req.rid) ++ key = f"{data_id}_{uuid.uuid4().hex[:8]}" ++ ++ seq_len = hidden_states.shape[0] ++ input_ids = torch.tensor( ++ req.origin_input_ids, dtype=torch.long, device=hidden_states.device ++ ) ++ ++ last_hidden_states = None ++ store_lhs = getattr(self.server_args, "spec_training_store_last_hidden_states", True) ++ if store_lhs and logits_output.last_hidden_states is not None: ++ last_hidden_states = logits_output.last_hidden_states[ ++ hidden_state_offset : hidden_state_offset + seq_len ++ ] ++ ++ if hidden_states.is_cuda and copy_done_event is not None: ++ torch.cuda.current_stream().wait_event(copy_done_event) ++ ++ self._put_spec_training_mooncake( ++ key=key, ++ hidden_states=hidden_states, ++ input_ids=input_ids, ++ last_hidden_states=last_hidden_states, ++ ) ++ ++ req.spec_training_mooncake_store_keys.append(key) ++ batch.spec_training_info.mooncake_store_keys[data_id].append(key) ++ + def stream_output( + self: Scheduler, + reqs: List[Req], +@@ -1055,6 +1156,18 @@ class SchedulerOutputProcessorMixin: + indexer_topk = None + customized_info = {} + ++ if self.attn_tp_rank == 0 and self.eagle_mooncake_store is not None: ++ self.eagle_mooncake_store.flush() ++ ++ if self.attn_tp_rank == 0: ++ spec_training_data_ids = [] ++ packed_loss_masks = [] ++ spec_training_mooncake_store_keys = [] ++ else: ++ spec_training_data_ids = None ++ packed_loss_masks = None ++ spec_training_mooncake_store_keys = None ++ + time_stats = [] + + if return_logprob: +@@ -1161,6 +1274,13 @@ class SchedulerOutputProcessorMixin: + req.spec_correct_drafts_histogram + ) + ++ if spec_training_data_ids is not None: ++ spec_training_data_ids.append(req.spec_training_data_id) ++ packed_loss_masks.append(req.packed_loss_mask) ++ spec_training_mooncake_store_keys.append( ++ req.spec_training_mooncake_store_keys ++ ) ++ + if return_logprob: + if ( + req.return_logprob +@@ -1231,9 +1351,15 @@ class SchedulerOutputProcessorMixin: + output_token_ids_logprobs_idx.append([]) + + if req.return_hidden_states: +- if output_hidden_states is None: +- output_hidden_states = [] +- output_hidden_states.append(req.hidden_states) ++ uses_mooncake = ( ++ self.attn_tp_rank == 0 ++ and req.spec_training_data_id is not None ++ and self.eagle_mooncake_store is not None ++ ) ++ if not uses_mooncake: ++ if output_hidden_states is None: ++ output_hidden_states = [] ++ output_hidden_states.append(req.hidden_states) + if req.return_routed_experts: + if routed_experts is None: + routed_experts = [] +@@ -1305,6 +1431,15 @@ class SchedulerOutputProcessorMixin: + retraction_counts=retraction_counts, + load=load, + dp_ranks=dp_ranks, ++ spec_training_data_ids=( ++ spec_training_data_ids if spec_training_data_ids else None ++ ), ++ packed_loss_masks=packed_loss_masks if packed_loss_masks else None, ++ spec_training_mooncake_store_keys=( ++ spec_training_mooncake_store_keys ++ if spec_training_mooncake_store_keys ++ else None ++ ), + ) + ) + +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index 6375a1d..8bc57be 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -1036,6 +1036,8 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): + need_wait_for_mm_inputs=obj.need_wait_for_mm_inputs, + num_items_assigned=obj.num_items_assigned, + multi_item_delimiter_indices=obj.multi_item_delimiter_indices, ++ spec_training_data_id=obj.spec_training_data_id, ++ packed_loss_mask=obj.packed_loss_mask, + ) + elif isinstance(obj, EmbeddingReqInput): + # Resolve unresolved embed overrides now that input_ids are available +@@ -1743,6 +1745,18 @@ class TokenizerManager(TokenizerControlMixin, TokenizerManagerScoreMixin): + if getattr(recv_obj, "dp_ranks", None): + meta_info["dp_rank"] = recv_obj.dp_ranks[i] + ++ if ( ++ hasattr(recv_obj, "spec_training_data_ids") ++ and recv_obj.spec_training_data_ids is not None ++ and i < len(recv_obj.spec_training_data_ids) ++ and recv_obj.spec_training_data_ids[i] is not None ++ ): ++ meta_info["spec_training_data_id"] = recv_obj.spec_training_data_ids[i] ++ meta_info["packed_loss_mask"] = recv_obj.packed_loss_masks[i] ++ meta_info["spec_training_mooncake_store_keys"] = ( ++ recv_obj.spec_training_mooncake_store_keys[i] ++ ) ++ + state.finished = recv_obj.finished_reasons[i] is not None + if isinstance(recv_obj, BatchStrOutput): + # Not all request types have `stream` (e.g., EmbeddingReqInput). Default to non-streaming. +diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py +index 1671c5e..6950f53 100644 +--- a/python/sglang/srt/model_executor/forward_batch_info.py ++++ b/python/sglang/srt/model_executor/forward_batch_info.py +@@ -440,6 +440,9 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): + # For dumper: request IDs for cross-step sequence tracking + rids: Optional[List[str]] = None + ++ # For spec training ++ has_spec_training: bool = False ++ + @classmethod + def init_new( + cls, +@@ -490,6 +493,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): + return_hidden_states_before_norm=batch.return_hidden_states_before_norm, + return_pooled_hidden_states=batch.return_pooled_hidden_states, + rids=[req.rid for req in batch.reqs], ++ has_spec_training=batch.has_spec_training, + ) + device = model_runner.device + +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index ff937dd..bdf14e7 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -440,6 +440,13 @@ class ModelRunner(ModelRunnerKVCacheMixin): + # if there is no aux layer, set to None + self.eagle_aux_hidden_state_layer_ids = None + ++ if self.server_args.enable_aux_hidden_states: ++ self.eagle_use_aux_hidden_state = True ++ if self.server_args.aux_hidden_state_layer_ids is not None: ++ self.eagle_aux_hidden_state_layer_ids = ( ++ self.server_args.aux_hidden_state_layer_ids ++ ) ++ + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, +@@ -901,6 +908,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): + include aux hidden state output paths. + """ + if self.eagle_use_aux_hidden_state: ++ if not hasattr(self.model, "set_eagle3_layers_to_capture"): ++ raise RuntimeError( ++ "Aux hidden state capture is not supported by this model." ++ ) + self.model.set_eagle3_layers_to_capture( + self.eagle_aux_hidden_state_layer_ids + ) +@@ -3499,6 +3510,10 @@ class ModelRunner(ModelRunnerKVCacheMixin): + """ + self._preprocess_logits(logits_output, forward_batch.sampling_info) + ++ # Spec training: skip sampling and return fake EOS tokens. ++ if logits_output.skip_sampling_next_token_ids is not None: ++ return logits_output.skip_sampling_next_token_ids ++ + # Sample the next tokens + next_token_ids = self.sampler( + logits_output, +diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py +index 432b9fb..9998f40 100644 +--- a/python/sglang/srt/models/qwen3_next.py ++++ b/python/sglang/srt/models/qwen3_next.py +@@ -966,6 +966,7 @@ class Qwen3NextForCausalLM(nn.Module): + self.logits_processor = LogitsProcessor(config) + # For EAGLE3 support + self.capture_aux_hidden_states = False ++ self.hot_token_id = None + + self._routed_experts_weights_of_layer = LazyValue( + lambda: { +@@ -1059,6 +1060,13 @@ class Qwen3NextForCausalLM(nn.Module): + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + ++ if "d2t" in name: ++ self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0]) ++ continue ++ ++ if "t2d" in name: ++ continue ++ + if is_mtp: + + if "mtp" not in name: +diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py +index 5f0dcb5..d9785e2 100644 +--- a/python/sglang/srt/models/qwen3_next_mtp.py ++++ b/python/sglang/srt/models/qwen3_next_mtp.py +@@ -84,6 +84,9 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM): + ) + self.logits_processor = LogitsProcessor(config) + ++ self.capture_aux_hidden_states = False ++ self.hot_token_id = None ++ + @torch.no_grad() + def forward( + self, +diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py +index 0492d40..3808c73 100644 +--- a/python/sglang/srt/server_args.py ++++ b/python/sglang/srt/server_args.py +@@ -606,6 +606,12 @@ class ServerArgs: + speculative_ngram_external_corpus_max_tokens: int = 10000000 + enable_multi_layer_eagle: bool = False + ++ # Spec training (for speculative decoding model training) ++ enable_spec_training_mooncake: bool = False ++ enable_aux_hidden_states: bool = False ++ aux_hidden_state_layer_ids: Optional[List[int]] = None ++ spec_training_store_last_hidden_states: bool = True ++ + # Expert parallelism + ep_size: int = 1 + moe_a2a_backend: Literal[ +@@ -5942,6 +5948,30 @@ class ServerArgs: + help="Enable multi-layer Eagle speculative decoding.", + ) + ++ # Spec training (for speculative decoding model training) ++ parser.add_argument( ++ "--enable-spec-training-mooncake", ++ action="store_true", ++ help="Enable EagleMooncakeStore for spec training hidden state transfer.", ++ ) ++ parser.add_argument( ++ "--enable-aux-hidden-states", ++ action="store_true", ++ help="Enable capturing auxiliary hidden states for supported models.", ++ ) ++ parser.add_argument( ++ "--aux-hidden-state-layer-ids", ++ type=int, ++ nargs="+", ++ help="Layer IDs to capture as auxiliary hidden states. If omitted, model defaults are used.", ++ ) ++ parser.add_argument( ++ "--spec-training-store-last-hidden-states", ++ action=argparse.BooleanOptionalAction, ++ default=True, ++ help="Whether to store last hidden states for spec training requests.", ++ ) ++ + # Expert parallelism + parser.add_argument( + "--expert-parallel-size", +diff --git a/python/sglang/srt/speculative/spec_training_info.py b/python/sglang/srt/speculative/spec_training_info.py +new file mode 100644 +index 0000000..24af14b +--- /dev/null ++++ b/python/sglang/srt/speculative/spec_training_info.py +@@ -0,0 +1,50 @@ ++from dataclasses import dataclass, field ++from typing import Dict, List, Optional ++ ++ ++@dataclass ++class SpecTrainingInfo: ++ """Tracks spec training info for requests in a batch. ++ ++ Keys: ++ - data_ids: rid -> data_id mapping ++ - packed_loss_masks: data_id -> packed_loss_mask string ++ - mooncake_store_keys: data_id -> list of keys ++ """ ++ ++ data_ids: Dict[str, str] = field(default_factory=dict) ++ packed_loss_masks: Dict[str, str] = field(default_factory=dict) ++ mooncake_store_keys: Dict[str, List[str]] = field(default_factory=dict) ++ ++ def add_request( ++ self, ++ rid: str, ++ data_id: Optional[str], ++ packed_loss_mask: Optional[str], ++ ): ++ """Add spec training info for a request if it's a spec training request.""" ++ if data_id is not None: ++ self.data_ids[rid] = data_id ++ self.packed_loss_masks[data_id] = packed_loss_mask ++ if data_id not in self.mooncake_store_keys: ++ self.mooncake_store_keys[data_id] = [] ++ ++ def has_request(self, rid: str) -> bool: ++ return rid in self.data_ids ++ ++ def set_mooncake_store_keys(self, data_id: str, keys: List[str]): ++ if data_id in self.mooncake_store_keys: ++ self.mooncake_store_keys[data_id] = keys ++ ++ def remove_request(self, rid: str): ++ data_id = self.data_ids.pop(rid, None) ++ if data_id is not None: ++ remaining_rids_with_data_id = [ ++ r for r, d in self.data_ids.items() if d == data_id ++ ] ++ if not remaining_rids_with_data_id: ++ self.packed_loss_masks.pop(data_id, None) ++ self.mooncake_store_keys.pop(data_id, None) ++ ++ def is_empty(self) -> bool: ++ return len(self.data_ids) == 0 diff --git a/tests/test_mooncake_force_delete.py b/tests/test_mooncake_force_delete.py index abb4907a..caf2c97e 100644 --- a/tests/test_mooncake_force_delete.py +++ b/tests/test_mooncake_force_delete.py @@ -69,6 +69,19 @@ def test_enable_hard_pin_default_off(self): class TestMooncakeEnvDefaults: + def test_from_env_accepts_kubernetes_service_port_urls(self): + env = { + "MOONCAKE_MASTER_HOST": "10.0.0.7", + "MOONCAKE_MASTER_PORT": "tcp://10.0.0.7:51135", + "MOONCAKE_METADATA_PORT": "tcp://10.0.0.7:8763", + } + + with patch.dict(os.environ, env, clear=True): + config = MooncakeConfig.from_env() + + assert config.master_server_address == "10.0.0.7:51135" + assert config.metadata_server == "http://10.0.0.7:8763/metadata" + def test_tcp_memcpy_default_is_applied_by_export_env(self): config = MooncakeConfig(protocol="tcp") diff --git a/tests/test_sglang_engine_integration.py b/tests/test_sglang_engine_integration.py index cad67203..a2f98f90 100644 --- a/tests/test_sglang_engine_integration.py +++ b/tests/test_sglang_engine_integration.py @@ -35,7 +35,7 @@ from transformers import AutoConfig # noqa: E402 # --------------------------------------------------------------------------- -# Mooncake env setup (must happen before any sglang import) +# Mooncake env setup # --------------------------------------------------------------------------- try: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -286,13 +286,14 @@ def main(): } torch.save(meta, dump_dir / "sglang_meta.pt") - # Import mooncake before creating sglang engine — sglang's subprocess - # forking can interfere with the import chain through torchspec.config.__init__ + engine = create_engine(args.model, args.tp, aux_layer_ids) + + # Import Mooncake after SGLang. In the CUDA 13 + mooncake-transfer-engine + # stack, importing mooncake.store before SGLang can segfault while torch's + # pybind dispatch bindings initialize. from torchspec.config.mooncake_config import MooncakeConfig from torchspec.transfer.mooncake.eagle_store import EagleMooncakeStore - engine = create_engine(args.model, args.tp, aux_layer_ids) - mooncake_config = MooncakeConfig.from_env() mooncake_store = EagleMooncakeStore(mooncake_config) mooncake_store.setup(device="cuda") diff --git a/torchspec/config/mooncake_config.py b/torchspec/config/mooncake_config.py index 9b1309f5..6a258aa7 100644 --- a/torchspec/config/mooncake_config.py +++ b/torchspec/config/mooncake_config.py @@ -21,6 +21,7 @@ import os from dataclasses import dataclass from typing import Tuple +from urllib.parse import urlparse from torchspec.transfer.mooncake.helpers import calculate_eagle3_buffer_size @@ -203,8 +204,8 @@ def apply_env_defaults(self) -> None: def from_env(cls) -> "MooncakeConfig": """Create config from environment variables.""" master_host = os.getenv("MOONCAKE_MASTER_HOST", "localhost") - master_port = os.getenv("MOONCAKE_MASTER_PORT", "50051") - metadata_port = os.getenv("MOONCAKE_METADATA_PORT", "8090") + master_port = cls._env_port("MOONCAKE_MASTER_PORT", "50051") + metadata_port = cls._env_port("MOONCAKE_METADATA_PORT", "8090") store_full_wait_seconds = float(os.getenv("MOONCAKE_STORE_FULL_WAIT_SECONDS", "0.5")) store_full_log_interval_seconds = float( os.getenv("MOONCAKE_STORE_FULL_LOG_INTERVAL_SECONDS", "5.0") @@ -251,6 +252,14 @@ def from_env(cls) -> "MooncakeConfig": enable_hard_pin=os.getenv("MOONCAKE_ENABLE_HARD_PIN", "0") == "1", ) + @staticmethod + def _env_port(name: str, default: str) -> str: + value = os.getenv(name, default) + parsed = urlparse(value) + if parsed.scheme and parsed.port is not None: + return str(parsed.port) + return value.rsplit(":", 1)[-1] if ":" in value else value + @classmethod def from_master_address( cls, From 1c022e4f968edc7a2b93cbe849cb5dfac7cfd675 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 29 May 2026 21:38:13 -0700 Subject: [PATCH 2/2] Clean up CUDA13 Mooncake install Signed-off-by: Yubo Wang --- docker/sglang/v0.5.12/Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docker/sglang/v0.5.12/Dockerfile b/docker/sglang/v0.5.12/Dockerfile index 2450dfcd..59ffc0dc 100644 --- a/docker/sglang/v0.5.12/Dockerfile +++ b/docker/sglang/v0.5.12/Dockerfile @@ -14,8 +14,9 @@ RUN cd /root/torchspec && pip install --no-cache-dir -e ".[fa]" # TorchSpec's dependency pulls the generic CUDA 12 Mooncake wheel. Replace it # with the CUDA 13 wheel that matches this image family. -RUN pip uninstall -y mooncake-transfer-engine mooncake-transfer-engine-cuda13 || true && \ - pip install --no-cache-dir --no-deps mooncake-transfer-engine-cuda13==0.3.11.post1 +RUN pip uninstall -y mooncake-transfer-engine || true && \ + pip install --no-cache-dir --no-deps --force-reinstall \ + mooncake-transfer-engine-cuda13==0.3.11.post1 RUN chmod 755 /usr/local/lib/python3.12/dist-packages/mooncake/mooncake_master || true RUN if [ -f /usr/local/lib/python3.12/dist-packages/mooncake/cli.py ]; then \