Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions xllm/core/util/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,25 +567,31 @@ bool torch_to_proto(const torch::Tensor& torch_tensor,
proto_contents, torch_tensor, proto_datatype);
break;
case torch::kBFloat16: {
// Need to convert bfloat16 to uint8_t for storage
auto bfloat16_ptr = torch_tensor.data_ptr<torch::BFloat16>();
// Need to convert bfloat16 to uint8_t for storage.
// Ensure contiguous layout first — data_ptr() reads raw memory
// which would be wrong for NHWC/ChannelsLast tensors.
torch::Tensor contig_tensor = torch_tensor.contiguous();
auto bfloat16_ptr = contig_tensor.data_ptr<torch::BFloat16>();
Comment thread
a120092009 marked this conversation as resolved.
uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(bfloat16_ptr);
torch::Tensor uint8_tensor =
torch::from_blob(uint8_ptr,
{static_cast<int64_t>(torch_tensor.numel() *
{static_cast<int64_t>(contig_tensor.numel() *
sizeof(torch::BFloat16))},
torch::dtype(torch::kUInt8));
data_set_success = set_data_to_contents<uint8_t>(
proto_contents, uint8_tensor, proto_datatype);
break;
}
case torch::kHalf: {
// Need to convert float16 to uint8_t for storage
auto float16_ptr = torch_tensor.data_ptr<torch::Half>();
// Need to convert float16 to uint8_t for storage.
// Ensure contiguous layout first — data_ptr() reads raw memory
// which would be wrong for NHWC/ChannelsLast tensors.
torch::Tensor contig_tensor = torch_tensor.contiguous();
auto float16_ptr = contig_tensor.data_ptr<torch::Half>();
Comment thread
a120092009 marked this conversation as resolved.
uint8_t* uint8_ptr = reinterpret_cast<uint8_t*>(float16_ptr);
torch::Tensor uint8_tensor = torch::from_blob(
uint8_ptr,
{static_cast<int64_t>(torch_tensor.numel() * sizeof(torch::Half))},
{static_cast<int64_t>(contig_tensor.numel() * sizeof(torch::Half))},
torch::dtype(torch::kUInt8));
data_set_success = set_data_to_contents<uint8_t>(
proto_contents, uint8_tensor, proto_datatype);
Expand Down
Loading