From 42b76460b774e22066c2a62bdcc0e4c441d0aeef Mon Sep 17 00:00:00 2001 From: a120092009 Date: Mon, 18 May 2026 10:51:33 +0800 Subject: [PATCH 1/2] bugfix: ensure tensor contiguous layout before protobuf serialization. NHWC (channels_last) tensors pass raw bytes directly into the proto buffer when serialized via torch_to_proto. On deserialization the reconstructed tensor assumes NCHW (contiguous) layout, which corrupts the channel order for bfloat16 and float16 image data. Force a contiguous copy before reading data_ptr() so the byte stream always matches the expected row-major order. --- xllm/core/util/utils.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp index 1e4757cd44..4d49f5136e 100644 --- a/xllm/core/util/utils.cpp +++ b/xllm/core/util/utils.cpp @@ -518,6 +518,8 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, return false; } + // Ensure contiguous (NCHW) layout before serializing raw bytes + auto contig_tensor = torch_tensor.contiguous(); torch::ScalarType torch_dtype = torch_tensor.scalar_type(); std::string proto_datatype = torch_datatype_to_proto(torch_dtype); if (proto_datatype.empty()) { @@ -568,7 +570,7 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, break; case torch::kBFloat16: { // Need to convert bfloat16 to uint8_t for storage - auto bfloat16_ptr = torch_tensor.data_ptr(); + auto bfloat16_ptr = contig_tensor.data_ptr(); uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(bfloat16_ptr); torch::Tensor uint8_tensor = torch::from_blob(uint8_ptr, @@ -581,7 +583,7 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, } case torch::kHalf: { // Need to convert float16 to uint8_t for storage - auto float16_ptr = torch_tensor.data_ptr(); + auto float16_ptr = contig_tensor.data_ptr(); uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(float16_ptr); torch::Tensor uint8_tensor = torch::from_blob( uint8_ptr, From bc833c82d5da70af5db74d75c1f7d5f74ecbb8c2 Mon Sep 17 00:00:00 2001 From: 
a120092009 Date: Mon, 18 May 2026 15:33:34 +0800 Subject: [PATCH 2/2] refactor: move contiguous call into bfloat16 and half serialization branches. Avoid redundant contiguous() for types where set_data_to_contents already handles the conversion, and use explicit torch::Tensor type per style guide. Co-Authored-By: Claude Opus 4.6 --- xllm/core/util/utils.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/xllm/core/util/utils.cpp b/xllm/core/util/utils.cpp index 4d49f5136e..987f4335f7 100644 --- a/xllm/core/util/utils.cpp +++ b/xllm/core/util/utils.cpp @@ -518,8 +518,6 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, return false; } - // Ensure contiguous (NCHW) layout before serializing raw bytes - auto contig_tensor = torch_tensor.contiguous(); torch::ScalarType torch_dtype = torch_tensor.scalar_type(); std::string proto_datatype = torch_datatype_to_proto(torch_dtype); if (proto_datatype.empty()) { @@ -569,12 +567,15 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, proto_contents, torch_tensor, proto_datatype); break; case torch::kBFloat16: { - // Need to convert bfloat16 to uint8_t for storage + // Need to convert bfloat16 to uint8_t for storage. + // Ensure contiguous layout first — data_ptr() reads raw memory + // which would be wrong for NHWC/ChannelsLast tensors. + torch::Tensor contig_tensor = torch_tensor.contiguous(); auto bfloat16_ptr = contig_tensor.data_ptr(); uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(bfloat16_ptr); torch::Tensor uint8_tensor = torch::from_blob(uint8_ptr, - {static_cast<int64_t>(torch_tensor.numel() * + {static_cast<int64_t>(contig_tensor.numel() * sizeof(torch::BFloat16))}, torch::dtype(torch::kUInt8)); data_set_success = set_data_to_contents( @@ -582,12 +583,15 @@ bool torch_to_proto(const torch::Tensor& torch_tensor, break; } case torch::kHalf: { - // Need to convert float16 to uint8_t for storage + // Need to convert float16 to uint8_t for storage. 
+ // Ensure contiguous layout first — data_ptr() reads raw memory + // which would be wrong for NHWC/ChannelsLast tensors. + torch::Tensor contig_tensor = torch_tensor.contiguous(); auto float16_ptr = contig_tensor.data_ptr(); uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(float16_ptr); torch::Tensor uint8_tensor = torch::from_blob( uint8_ptr, - {static_cast<int64_t>(torch_tensor.numel() * sizeof(torch::Half))}, + {static_cast<int64_t>(contig_tensor.numel() * sizeof(torch::Half))}, torch::dtype(torch::kUInt8)); data_set_success = set_data_to_contents( proto_contents, uint8_tensor, proto_datatype);