diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp index 1e4757cd44..987f4335f7 100644 --- a/xllm/core/util/utils.cpp +++ b/xllm/core/util/utils.cpp @@ -567,12 +567,15 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, proto_contents, torch_tensor, proto_datatype); break; case torch::kBFloat16: { - // Need to convert bfloat16 to uint8_t for storage - auto bfloat16_ptr = torch_tensor.data_ptr(); + // Need to convert bfloat16 to uint8_t for storage. + // Ensure contiguous layout first — data_ptr() reads raw memory + // which would be wrong for NHWC/ChannelsLast tensors. + torch::Tensor contig_tensor = torch_tensor.contiguous(); + auto bfloat16_ptr = contig_tensor.data_ptr(); uint8_t* uint8_ptr = reinterpret_cast(bfloat16_ptr); torch::Tensor uint8_tensor = torch::from_blob(uint8_ptr, - {static_cast(torch_tensor.numel() * + {static_cast(contig_tensor.numel() * sizeof(torch::BFloat16))}, torch::dtype(torch::kUInt8)); data_set_success = set_data_to_contents( @@ -580,12 +583,15 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, break; } case torch::kHalf: { - // Need to convert float16 to uint8_t for storage - auto float16_ptr = torch_tensor.data_ptr(); + // Need to convert float16 to uint8_t for storage. + // Ensure contiguous layout first — data_ptr() reads raw memory + // which would be wrong for NHWC/ChannelsLast tensors. + torch::Tensor contig_tensor = torch_tensor.contiguous(); + auto float16_ptr = contig_tensor.data_ptr(); uint8_t* uint8_ptr = reinterpret_cast(float16_ptr); torch::Tensor uint8_tensor = torch::from_blob( uint8_ptr, - {static_cast(torch_tensor.numel() * sizeof(torch::Half))}, + {static_cast(contig_tensor.numel() * sizeof(torch::Half))}, torch::dtype(torch::kUInt8)); data_set_success = set_data_to_contents( proto_contents, uint8_tensor, proto_datatype);