From c3ba66f61dc4dc4799a12b279a1fc034b5f7d9a7 Mon Sep 17 00:00:00 2001 From: Super User Date: Mon, 25 May 2026 19:27:31 +0800 Subject: [PATCH] feat: support the v1/video/generation for wan22. Co-Authored-By: Claude Opus 4.7 --- CMakeLists.txt | 1 + vcpkg.json | 5 + xllm/CMakeLists.txt | 2 +- xllm/api_service/CMakeLists.txt | 2 + xllm/api_service/api_service.cpp | 47 +++ xllm/api_service/api_service.h | 12 + xllm/api_service/service_impl_factory.cpp | 2 + .../video_generation_service_impl.cpp | 103 +++++ .../video_generation_service_impl.h | 42 ++ xllm/core/framework/batch/dit_batch.cpp | 23 +- xllm/core/framework/request/CMakeLists.txt | 1 + xllm/core/framework/request/dit_request.cpp | 18 +- .../framework/request/dit_request_output.h | 14 +- .../framework/request/dit_request_params.cpp | 241 +++++++---- .../framework/request/dit_request_params.h | 6 +- .../framework/request/dit_request_state.h | 32 +- xllm/core/framework/request/mm_codec.cpp | 381 ++++++++++++++++++ xllm/core/framework/request/mm_codec.h | 13 + xllm/core/runtime/dit_forward_params.h | 44 +- .../runtime/forward_shared_memory_manager.cpp | 47 ++- xllm/core/runtime/params_utils.cpp | 35 ++ xllm/proto/CMakeLists.txt | 1 + xllm/proto/image_generation.proto | 3 + xllm/proto/video_generation.proto | 133 ++++++ xllm/proto/worker.proto | 23 +- xllm/proto/xllm_service.proto | 4 + xllm/server/xllm_server.cpp | 1 + 27 files changed, 1130 insertions(+), 106 deletions(-) create mode 100644 xllm/api_service/video_generation_service_impl.cpp create mode 100644 xllm/api_service/video_generation_service_impl.h create mode 100644 xllm/proto/video_generation.proto diff --git a/CMakeLists.txt b/CMakeLists.txt index 5971ff6c5e..426174c864 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -275,6 +275,7 @@ find_package(GTest CONFIG REQUIRED) find_package(benchmark CONFIG REQUIRED) find_package(nlohmann_json CONFIG REQUIRED) find_package(OpenCV CONFIG REQUIRED) +find_package(FFMPEG REQUIRED) find_package(Python 
COMPONENTS Development REQUIRED) find_package(pybind11 CONFIG REQUIRED) diff --git a/vcpkg.json b/vcpkg.json index d5a67b85fe..ef36b5d94a 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -71,6 +71,11 @@ "name": "snappy", "version>=": "1.1.10" }, + { + "name": "ffmpeg", + "default-features": false, + "features": ["avcodec", "avformat", "swscale", "swresample", "x264"] + }, { "name": "opencv4", "version>=": "4.7.0", diff --git a/xllm/CMakeLists.txt b/xllm/CMakeLists.txt index 5fdce0ba7a..b9fb9d15f8 100644 --- a/xllm/CMakeLists.txt +++ b/xllm/CMakeLists.txt @@ -92,7 +92,7 @@ else() endif() # link brpc -target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb protobuf::libprotobuf ${OpenCV_LIBS}) +target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb protobuf::libprotobuf ${OpenCV_LIBS} ${FFMPEG_LIBRARIES}) add_dependencies(xllm brpc-static) if(USE_NPU) diff --git a/xllm/api_service/CMakeLists.txt b/xllm/api_service/CMakeLists.txt index df025345e5..f42c9ea73d 100644 --- a/xllm/api_service/CMakeLists.txt +++ b/xllm/api_service/CMakeLists.txt @@ -16,6 +16,7 @@ cc_library( embedding_service_impl.h audio_generation_service_impl.h image_generation_service_impl.h + video_generation_service_impl.h rerank_service_impl.h qwen3_rerank_service_impl.h non_stream_call.h @@ -40,6 +41,7 @@ cc_library( embedding_service_impl.cpp audio_generation_service_impl.cpp image_generation_service_impl.cpp + video_generation_service_impl.cpp models_service_impl.cpp rerank_service_impl.cpp stream_output_parser.cpp diff --git a/xllm/api_service/api_service.cpp b/xllm/api_service/api_service.cpp index 12d9e833df..1a0a380655 100644 --- a/xllm/api_service/api_service.cpp +++ b/xllm/api_service/api_service.cpp @@ -42,6 +42,7 @@ limitations under the License. 
#include "image_generation.pb.h" #include "models.pb.h" #include "service_impl_factory.h" +#include "video_generation.pb.h" #include "xllm_metrics.h" namespace xllm { @@ -533,6 +534,52 @@ void APIService::AudioGenerationHttp( audio_generation_service_impl_->process_async(call); } +void APIService::VideoGeneration(::google::protobuf::RpcController* controller, + const proto::VideoGenerationRequest* request, + proto::VideoGenerationResponse* response, + ::google::protobuf::Closure* done) { + // TODO with xllm-service +} + +void APIService::VideoGenerationHttp( + ::google::protobuf::RpcController* controller, + const proto::HttpRequest* request, + proto::HttpResponse* response, + ::google::protobuf::Closure* done) { + xllm::ClosureGuard done_guard( + done, + std::bind(request_in_metric, nullptr), + std::bind(request_out_metric, (void*)controller)); + if (!request || !response || !controller) { + LOG(ERROR) << "brpc request | respose | controller is null"; + return; + } + + auto arena = GetArenaWithCheck(response); + auto req_pb = + google::protobuf::Arena::CreateMessage( + arena); + auto resp_pb = + google::protobuf::Arena::CreateMessage( + arena); + + auto ctrl = reinterpret_cast(controller); + std::string error; + json2pb::Json2PbOptions options; + butil::IOBuf& buf = ctrl->request_attachment(); + butil::IOBufAsZeroCopyInputStream iobuf_stream(buf); + auto st = json2pb::JsonToProtoMessage(&iobuf_stream, req_pb, options, &error); + if (!st) { + ctrl->SetFailed(error); + LOG(ERROR) << "parse json to proto failed: " << error; + return; + } + std::shared_ptr call = + std::make_shared( + ctrl, done_guard.release(), req_pb, resp_pb, arena != nullptr); + video_generation_service_impl_->process_async(call); +} + void APIService::Rerank(::google::protobuf::RpcController* controller, const proto::RerankRequest* request, proto::RerankResponse* response, diff --git a/xllm/api_service/api_service.h b/xllm/api_service/api_service.h index 48560b8f39..e4b97fff17 100644 --- 
a/xllm/api_service/api_service.h +++ b/xllm/api_service/api_service.h @@ -31,6 +31,7 @@ limitations under the License. #include "rec_completion_service_impl.h" #include "rerank_service_impl.h" #include "sample_service_impl.h" +#include "video_generation_service_impl.h" #include "xllm_service.pb.h" namespace xllm { @@ -107,6 +108,16 @@ class APIService : public proto::XllmAPIService { proto::HttpResponse* response, ::google::protobuf::Closure* done) override; + void VideoGeneration(::google::protobuf::RpcController* controller, + const proto::VideoGenerationRequest* request, + proto::VideoGenerationResponse* response, + ::google::protobuf::Closure* done) override; + + void VideoGenerationHttp(::google::protobuf::RpcController* controller, + const proto::HttpRequest* request, + proto::HttpResponse* response, + ::google::protobuf::Closure* done) override; + void Rerank(::google::protobuf::RpcController* controller, const proto::RerankRequest* request, proto::RerankResponse* response, @@ -216,6 +227,7 @@ class APIService : public proto::XllmAPIService { std::unique_ptr models_service_impl_; std::unique_ptr image_generation_service_impl_; std::unique_ptr audio_generation_service_impl_; + std::unique_ptr video_generation_service_impl_; std::unique_ptr rerank_service_impl_; std::unique_ptr rec_completion_service_impl_; }; diff --git a/xllm/api_service/service_impl_factory.cpp b/xllm/api_service/service_impl_factory.cpp index b7140f3918..99c8ea7081 100644 --- a/xllm/api_service/service_impl_factory.cpp +++ b/xllm/api_service/service_impl_factory.cpp @@ -93,6 +93,8 @@ void ServiceImplFactory::create( std::make_unique(dit_master, models); self->audio_generation_service_impl_ = std::make_unique(dit_master, models); + self->video_generation_service_impl_ = + std::make_unique(dit_master, models); }}, {static_cast(ServingMode::REC), [](APIService* self, diff --git a/xllm/api_service/video_generation_service_impl.cpp b/xllm/api_service/video_generation_service_impl.cpp new file 
mode 100644 index 0000000000..17d41c88f0 --- /dev/null +++ b/xllm/api_service/video_generation_service_impl.cpp @@ -0,0 +1,103 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "video_generation_service_impl.h" + +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "api_service/utils.h" +#include "core/framework/request/dit_request_params.h" +#include "distributed_runtime/dit_master.h" + +namespace xllm { + +namespace { + +bool send_result_to_client_brpc(std::shared_ptr call, + const std::string& request_id, + int64_t created_time, + const std::string& model, + const DiTRequestOutput& req_output) { + auto& response = call->response(); + response.set_object("list"); + response.set_id(request_id); + response.set_created(created_time); + response.set_model(model); + auto* proto_output = response.mutable_output(); + const std::vector& outputs = req_output.outputs; + proto_output->mutable_results()->Reserve(outputs.size()); + + std::string video; + for (const auto& output : outputs) { + auto* proto_result = proto_output->add_results(); + + video.clear(); + butil::Base64Encode(output.image, &video); + proto_result->set_video(video); + + proto_result->set_width(output.width); + proto_result->set_height(output.height); + proto_result->set_seed(output.seed); + 
proto_result->set_num_frames(output.num_frames); + proto_result->set_fps(output.video_fps); + } + return call->write_and_finish(response); +} + +} // namespace + +VideoGenerationServiceImpl::VideoGenerationServiceImpl( + DiTMaster* master, + const std::vector& models) + : APIServiceImpl(models), master_{master} { + CHECK(master_ != nullptr); +} + +void VideoGenerationServiceImpl::process_async_impl( + std::shared_ptr call) { + const auto& rpc_request = call->request(); + const auto& model = rpc_request.model(); + if (!models_.contains(model)) { + call->finish_with_error(StatusCode::UNKNOWN, "Model not supported"); + return; + } + + DiTRequestParams request_params( + rpc_request, call->get_x_request_id(), call->get_x_request_time()); + + std::string saved_request_id = request_params.request_id; + master_->handle_request( + std::move(request_params), + call.get(), + [call, + model, + request_id = std::move(saved_request_id), + created_time = absl::ToUnixSeconds(absl::Now())]( + const DiTRequestOutput& req_output) -> bool { + if (req_output.status.has_value()) { + const auto& status = req_output.status.value(); + if (!status.ok()) { + return call->finish_with_error(status.code(), status.message()); + } + } + + return send_result_to_client_brpc( + call, request_id, created_time, model, req_output); + }); +} + +} // namespace xllm diff --git a/xllm/api_service/video_generation_service_impl.h b/xllm/api_service/video_generation_service_impl.h new file mode 100644 index 0000000000..cb61a65357 --- /dev/null +++ b/xllm/api_service/video_generation_service_impl.h @@ -0,0 +1,42 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once +#include + +#include "api_service/api_service_impl.h" +#include "api_service/non_stream_call.h" +#include "video_generation.pb.h" + +namespace xllm { + +using VideoGenerationCall = NonStreamCall; +class DiTMaster; +// Handles /v1/video/generation requests +class VideoGenerationServiceImpl final + : public APIServiceImpl { + public: + VideoGenerationServiceImpl(DiTMaster* master, + const std::vector& models); + + void process_async_impl(std::shared_ptr call); + + private: + DISALLOW_COPY_AND_ASSIGN(VideoGenerationServiceImpl); + DiTMaster* master_ = nullptr; +}; + +} // namespace xllm diff --git a/xllm/core/framework/batch/dit_batch.cpp b/xllm/core/framework/batch/dit_batch.cpp index 67a373fcf8..3d1f9cbe55 100644 --- a/xllm/core/framework/batch/dit_batch.cpp +++ b/xllm/core/framework/batch/dit_batch.cpp @@ -63,9 +63,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() { std::vector images; std::vector mask_images; std::vector control_images; + std::vector condition_images; std::vector latents; std::vector masked_image_latents; - + std::vector last_images; + std::vector image_embeds; const auto batch_size = request_vec_.size(); prompt_embeds.reserve(batch_size); pooled_prompt_embeds.reserve(batch_size); @@ -74,13 +76,15 @@ DiTForwardInput DiTBatch::prepare_forward_input() { images.reserve(batch_size); mask_images.reserve(batch_size); control_images.reserve(batch_size); + condition_images.reserve(batch_size); latents.reserve(batch_size); 
masked_image_latents.reserve(batch_size); + last_images.reserve(batch_size); + image_embeds.reserve(batch_size); std::vector images_list; size_t images_size = request_vec_[0]->state().input_params().images.size(); bool images_size_valid = images_size > 0; - for (const auto& request : request_vec_) { const auto& generation_params = request->state().generation_params(); if (input.generation_params != generation_params) { @@ -113,6 +117,9 @@ DiTForwardInput DiTBatch::prepare_forward_input() { images.emplace_back(input_params.image); mask_images.emplace_back(input_params.mask_image); control_images.emplace_back(input_params.control_image); + condition_images.emplace_back(input_params.condition_image); + last_images.emplace_back(input_params.last_image); + image_embeds.emplace_back(input_params.image_embeds); if (input_params.images.size() != images_size) { images_size_valid = false; @@ -180,6 +187,10 @@ DiTForwardInput DiTBatch::prepare_forward_input() { input.control_image = torch::stack(control_images); } + if (check_tensors_valid(condition_images)) { + input.condition_images = torch::stack(condition_images); + } + if (check_tensors_valid(prompt_embeds)) { input.prompt_embeds = torch::stack(prompt_embeds); } @@ -204,6 +215,14 @@ DiTForwardInput DiTBatch::prepare_forward_input() { if (check_tensors_valid(masked_image_latents)) { input.masked_image_latents = torch::stack(masked_image_latents); } + + if (check_tensors_valid(last_images)) { + input.last_images = torch::stack(last_images); + } + + if (check_tensors_valid(image_embeds)) { + input.image_embeds = torch::stack(image_embeds); + } return input; } diff --git a/xllm/core/framework/request/CMakeLists.txt b/xllm/core/framework/request/CMakeLists.txt index 19b57d8773..4cd35456df 100644 --- a/xllm/core/framework/request/CMakeLists.txt +++ b/xllm/core/framework/request/CMakeLists.txt @@ -74,4 +74,5 @@ cc_library( proto::xllm_proto torch ${OpenCV_LIBS} + ${FFMPEG_LIBRARIES} ) diff --git 
a/xllm/core/framework/request/dit_request.cpp b/xllm/core/framework/request/dit_request.cpp index bfa9e4463c..09ecdfb4ad 100644 --- a/xllm/core/framework/request/dit_request.cpp +++ b/xllm/core/framework/request/dit_request.cpp @@ -117,9 +117,11 @@ void DiTRequest::log_statistic(double total_latency) { void DiTRequest::handle_forward_output(torch::Tensor output) { // Pipeline already chunks by batch size along dim 0 before calling here. // For image models, also split by num_images_per_prompt. + // For video models, split by num_images_per_prompt * num_videos_per_prompt. // For audio models, num_images_per_prompt defaults to 1 so this is a no-op. const int32_t count = - static_cast(state_.generation_params().num_images_per_prompt); + static_cast(state_.generation_params().num_images_per_prompt * + state_.generation_params().num_videos_per_prompt); output_.tensors = torch::chunk(output, count); } @@ -142,8 +144,10 @@ const DiTRequestOutput DiTRequest::generate_output() { } const int32_t count = - static_cast(state_.generation_params().num_images_per_prompt); + static_cast(state_.generation_params().num_images_per_prompt * + state_.generation_params().num_videos_per_prompt); OpenCVImageEncoder image_encoder; + FFmpegVideoEncoder video_encoder; for (size_t idx = 0; idx < count; ++idx) { torch::Tensor output_tensor = output_.tensors[idx].squeeze(0).cpu().to(torch::kFloat32).contiguous(); @@ -152,6 +156,16 @@ const DiTRequestOutput DiTRequest::generate_output() { encode_wav(samples, state_.generation_params().audio_sampling_rate, result.audio); + } else if (output_tensor.dim() == 4 || + state_.generation_params().force_video_output) { + video_encoder.encode(output_tensor, + state_.generation_params().video_fps, + "mp4", + result.image); + result.num_frames = output_tensor.dim() == 4 + ? 
static_cast(output_tensor.size(0)) + : 0; + result.video_fps = state_.generation_params().video_fps; } else { image_encoder.encode(output_tensor, result.image); } diff --git a/xllm/core/framework/request/dit_request_output.h b/xllm/core/framework/request/dit_request_output.h index 1eb7e8ada6..9145d9baeb 100644 --- a/xllm/core/framework/request/dit_request_output.h +++ b/xllm/core/framework/request/dit_request_output.h @@ -31,20 +31,26 @@ struct DiTGenerationOutput { // the index of the sequence in the request. size_t index; - // the generated image in torch tensor format. + // the encoded image/video bytes; base64-encoded later by the API service. std::string image; // the generated audio as raw WAV bytes (audio models only). std::string audio; - // the height of the generated image. + // the height of the generated image/video. int32_t height; - // the width of the generated image. + // the width of the generated image/video. int32_t width; // seed used for generation. int64_t seed; + + // number of video frames + int32_t num_frames = 0; + + // video fps + double video_fps = 0.0; }; struct DiTRequestOutput { @@ -82,4 +88,4 @@ using DiTOutputCallback = std::function; using BatchDiTOutputCallback = std::function; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/request/dit_request_params.cpp b/xllm/core/framework/request/dit_request_params.cpp index 73c7b0fa26..cdce7dbdd1 100644 --- a/xllm/core/framework/request/dit_request_params.cpp +++ b/xllm/core/framework/request/dit_request_params.cpp @@ -29,8 +29,8 @@ namespace xllm { namespace { thread_local ShortUUID short_uuid; -std::string generate_image_generation_request_id() { - return "imggen-" + InstanceName::name()->get_name_hash() + "-" + +std::string generate_request_id(const std::string& prefix) { + return prefix + InstanceName::name()->get_name_hash() + "-" + short_uuid.random(); } @@ -72,115 +72,110 @@ bool decode_prompt_audio(const std::string& b64_audio, } // namespace 
-std::pair splitResolution(const std::string& s) { +std::pair split_resolution(const std::string& s) { size_t pos = s.find('*'); int width = std::stoi(s.substr(0, pos)); int height = std::stoi(s.substr(pos + 1)); return {width, height}; } -DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, - const std::string& x_rid, - const std::string& x_rtime) { - if (request.has_request_id()) { - request_id = request.request_id(); - } else { - request_id = generate_image_generation_request_id(); +// Decode a base64-encoded image string into a torch tensor via OpenCV. +bool decode_base64_image(const std::string& base64, torch::Tensor& out) { + std::string raw_bytes; + if (!butil::Base64Decode(base64, &raw_bytes)) { + LOG(ERROR) << "Base64 decode failed"; + return false; } - x_request_id = x_rid; - x_request_time = x_rtime; - - model = request.model(); + OpenCVImageDecoder decoder; + if (!decoder.decode(raw_bytes, out)) { + LOG(ERROR) << "Image decode failed"; + return false; + } + return true; +} - // input params - const auto& input = request.input(); +// Shared helper: populate DiTInputParams from an image-like input proto. 
+template +void fill_input_params(DiTInputParams& input_params, const InputProto& input) { + // Fields common to both proto::Input and proto::VideoInput input_params.prompt = input.prompt(); - if (input.has_prompt_2()) { - input_params.prompt_2 = input.prompt_2(); - } if (input.has_negative_prompt()) { input_params.negative_prompt = input.negative_prompt(); } - if (input.has_negative_prompt_2()) { - input_params.negative_prompt_2 = input.negative_prompt_2(); - } - if (input.has_prompt_embed()) { input_params.prompt_embed = util::proto_to_torch(input.prompt_embed()); } - if (input.has_pooled_prompt_embed()) { - input_params.pooled_prompt_embed = - util::proto_to_torch(input.pooled_prompt_embed()); - } if (input.has_negative_prompt_embed()) { input_params.negative_prompt_embed = util::proto_to_torch(input.negative_prompt_embed()); } - if (input.has_negative_pooled_prompt_embed()) { - input_params.negative_pooled_prompt_embed = - util::proto_to_torch(input.negative_pooled_prompt_embed()); - } - if (input.has_latent()) { - input_params.latent = util::proto_to_torch(input.latent()); - } - - if (input.has_masked_image_latent()) { - input_params.masked_image_latent = - util::proto_to_torch(input.masked_image_latent()); + if (input.has_image()) { + decode_base64_image(input.image(), input_params.image); } - OpenCVImageDecoder decoder; - if (input.has_image()) { - std::string raw_bytes; - if (!butil::Base64Decode(input.image(), &raw_bytes)) { - LOG(ERROR) << "Base64 image decode failed"; + // Image-only fields (proto::Input) + if constexpr (std::is_same_v) { + if (input.has_prompt_2()) { + input_params.prompt_2 = input.prompt_2(); } - if (!decoder.decode(raw_bytes, input_params.image)) { - LOG(ERROR) << "Image decode failed."; + if (input.has_negative_prompt_2()) { + input_params.negative_prompt_2 = input.negative_prompt_2(); } - } - - input_params.images.reserve(input.images().size()); - for (const auto& image : input.images()) { - std::string binary; - if 
(!butil::Base64Decode(image, &binary)) { - LOG(ERROR) << "Base64 image decode failed"; - continue; + if (input.has_pooled_prompt_embed()) { + input_params.pooled_prompt_embed = + util::proto_to_torch(input.pooled_prompt_embed()); } - torch::Tensor tensor; - if (!decoder.decode(binary, tensor)) { - LOG(ERROR) << "Image decode failed."; - continue; + if (input.has_negative_pooled_prompt_embed()) { + input_params.negative_pooled_prompt_embed = + util::proto_to_torch(input.negative_pooled_prompt_embed()); } - input_params.images.emplace_back(std::move(tensor)); - } - - if (input.has_mask_image()) { - std::string raw_bytes; - if (!butil::Base64Decode(input.mask_image(), &raw_bytes)) { - LOG(ERROR) << "Base64 mask_image decode failed"; + if (input.has_latent()) { + input_params.latent = util::proto_to_torch(input.latent()); + } + if (input.has_masked_image_latent()) { + input_params.masked_image_latent = + util::proto_to_torch(input.masked_image_latent()); } - if (!decoder.decode(raw_bytes, input_params.mask_image)) { - LOG(ERROR) << "Mask_image decode failed."; + if (input.has_mask_image()) { + decode_base64_image(input.mask_image(), input_params.mask_image); + } + input_params.images.reserve(input.images().size()); + for (const auto& image : input.images()) { + torch::Tensor tensor; + if (!decode_base64_image(image, tensor)) { + continue; + } + input_params.images.emplace_back(std::move(tensor)); + } + if (input.has_condition_image()) { + decode_base64_image(input.condition_image(), + input_params.condition_image); + } + if (input.has_control_image()) { + decode_base64_image(input.control_image(), input_params.control_image); } } - if (input.has_control_image()) { - std::string raw_bytes; - if (!butil::Base64Decode(input.control_image(), &raw_bytes)) { - LOG(ERROR) << "Base64 control_image decode failed"; + // Video-only fields (proto::VideoInput) + if constexpr (std::is_same_v) { + if (input.has_last_image()) { + decode_base64_image(input.last_image(), 
input_params.last_image); } - if (!decoder.decode(raw_bytes, input_params.control_image)) { - LOG(ERROR) << "Control_image decode failed."; + if (input.has_image_embeds()) { + input_params.image_embeds = util::proto_to_torch(input.image_embeds()); } } +} - // generation params - const auto& params = request.parameters(); +// Shared helper: populate generation params from a parameters proto. +// Both ImageParameters and VideoParameters share most fields. +template +void fill_generation_params(DiTGenerationParams& generation_params, + const ParamsProto& params) { if (params.has_size()) { - auto size = splitResolution(params.size()); - generation_params.width = size.first; - generation_params.height = size.second; + auto [w, h] = split_resolution(params.size()); + generation_params.width = w; + generation_params.height = h; } if (params.has_num_inference_steps()) { generation_params.num_inference_steps = params.num_inference_steps(); @@ -191,26 +186,52 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, if (params.has_guidance_scale()) { generation_params.guidance_scale = params.guidance_scale(); } - if (params.has_num_images_per_prompt()) { - generation_params.num_images_per_prompt = - static_cast(params.num_images_per_prompt()); - } else { - generation_params.num_images_per_prompt = 1; - } if (params.has_seed()) { generation_params.seed = params.seed(); } if (params.has_max_sequence_length()) { generation_params.max_sequence_length = params.max_sequence_length(); } - if (params.has_enable_cfg_renorm()) { - generation_params.enable_cfg_renorm = params.enable_cfg_renorm(); + if constexpr (std::is_same_v) { + if (params.has_enable_cfg_renorm()) { + generation_params.enable_cfg_renorm = params.enable_cfg_renorm(); + } + if (params.has_cfg_renorm_min()) { + generation_params.cfg_renorm_min = params.cfg_renorm_min(); + } + } +} + +// --------------------------------------------------------------------------- +// Image generation constructor 
+// --------------------------------------------------------------------------- +DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, + const std::string& x_rid, + const std::string& x_rtime) { + request_id = request.has_request_id() ? request.request_id() + : generate_request_id("imggen-"); + x_request_id = x_rid; + x_request_time = x_rtime; + model = request.model(); + + if (request.has_input()) { + fill_input_params(input_params, request.input()); } - if (params.has_cfg_renorm_min()) { - generation_params.cfg_renorm_min = params.cfg_renorm_min(); + + generation_params.num_images_per_prompt = 1; + if (request.has_parameters()) { + const auto& params = request.parameters(); + fill_generation_params(generation_params, params); + if (params.has_num_images_per_prompt()) { + generation_params.num_images_per_prompt = + static_cast(params.num_images_per_prompt()); + } } } +// --------------------------------------------------------------------------- +// Audio generation constructor +// --------------------------------------------------------------------------- DiTRequestParams::DiTRequestParams(const proto::AudioGenerationRequest& request, const std::string& x_rid, const std::string& x_rtime) { @@ -262,6 +283,52 @@ DiTRequestParams::DiTRequestParams(const proto::AudioGenerationRequest& request, } } +// --------------------------------------------------------------------------- +// Video generation constructor +// --------------------------------------------------------------------------- +DiTRequestParams::DiTRequestParams(const proto::VideoGenerationRequest& request, + const std::string& x_rid, + const std::string& x_rtime) { + request_id = request.has_request_id() ? 
request.request_id() + : generate_request_id("vidgen-"); + x_request_id = x_rid; + x_request_time = x_rtime; + model = request.model(); + + generation_params.force_video_output = true; + + if (request.has_input()) { + fill_input_params(input_params, request.input()); + } + + if (request.has_parameters()) { + const auto& params = request.parameters(); + fill_generation_params(generation_params, params); + if (params.has_num_videos_per_prompt()) { + generation_params.num_videos_per_prompt = + static_cast(params.num_videos_per_prompt()); + } + if (params.has_num_frames()) { + generation_params.num_frames = params.num_frames(); + } + if (params.has_fps()) { + generation_params.video_fps = params.fps(); + } + if (params.has_guidance_scale_2()) { + generation_params.guidance_scale_2 = params.guidance_scale_2(); + } + if (params.has_seconds()) { + generation_params.seconds = params.seconds(); + } + if (params.has_boundary_ratio()) { + generation_params.boundary_ratio = params.boundary_ratio(); + } + if (params.has_flow_shift()) { + generation_params.flow_shift = params.flow_shift(); + } + } +} + bool DiTRequestParams::verify_params( std::function callback) const { if (input_params.prompt.empty() && !input_params.prompt_embed.defined()) { diff --git a/xllm/core/framework/request/dit_request_params.h b/xllm/core/framework/request/dit_request_params.h index 8905481860..332983512e 100644 --- a/xllm/core/framework/request/dit_request_params.h +++ b/xllm/core/framework/request/dit_request_params.h @@ -13,6 +13,7 @@ #include "image_generation.pb.h" #include "request.h" #include "tensor.pb.h" +#include "video_generation.pb.h" namespace xllm { struct DiTRequestParams { @@ -23,6 +24,9 @@ struct DiTRequestParams { DiTRequestParams(const proto::AudioGenerationRequest& request, const std::string& x_rid, const std::string& x_rtime); + DiTRequestParams(const proto::VideoGenerationRequest& request, + const std::string& x_rid, + const std::string& x_rtime); bool 
verify_params(DiTOutputCallback callback) const; @@ -39,4 +43,4 @@ struct DiTRequestParams { DiTGenerationParams generation_params; }; -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/framework/request/dit_request_state.h b/xllm/core/framework/request/dit_request_state.h index 1796605656..5edc2fd42e 100644 --- a/xllm/core/framework/request/dit_request_state.h +++ b/xllm/core/framework/request/dit_request_state.h @@ -47,7 +47,14 @@ struct DiTGenerationParams { audio_duration_frames == other.audio_duration_frames && audio_steps == other.audio_steps && audio_guidance_method == other.audio_guidance_method && - audio_sampling_rate == other.audio_sampling_rate; + audio_sampling_rate == other.audio_sampling_rate && + num_videos_per_prompt == other.num_videos_per_prompt && + num_frames == other.num_frames && + force_video_output == other.force_video_output && + video_fps == other.video_fps && + guidance_scale_2 == other.guidance_scale_2 && + seconds == other.seconds && boundary_ratio == other.boundary_ratio && + flow_shift == other.flow_shift; } bool operator!=(const DiTGenerationParams& other) const { @@ -66,6 +73,8 @@ struct DiTGenerationParams { uint32_t num_images_per_prompt = 1; + uint32_t num_videos_per_prompt = 1; + int64_t seed = 0; int32_t max_sequence_length = 512; @@ -88,6 +97,20 @@ struct DiTGenerationParams { // Audio sample rate in Hz, read from model config.json (sampling_rate). 
int32_t audio_sampling_rate = 24000; + + int32_t num_frames = 81; + + bool force_video_output = false; + + double video_fps = 8.0; + + float guidance_scale_2 = 1.0; + + int32_t seconds = 5; + + float boundary_ratio = 0.9f; + + float flow_shift = 1.0f; }; struct DiTInputParams { @@ -119,10 +142,17 @@ struct DiTInputParams { torch::Tensor control_image; + torch::Tensor condition_image; + torch::Tensor mask_image; torch::Tensor masked_image_latent; + // Video-specific input fields + torch::Tensor last_image; + + torch::Tensor image_embeds; + // Prompt audio for voice cloning (LongCat-AudioDiT). // Float32 PCM, shape (1, num_samples), mono 24 kHz. torch::Tensor prompt_audio; diff --git a/xllm/core/framework/request/mm_codec.cpp b/xllm/core/framework/request/mm_codec.cpp index 696fb4e3d4..a2fcb41f72 100644 --- a/xllm/core/framework/request/mm_codec.cpp +++ b/xllm/core/framework/request/mm_codec.cpp @@ -15,11 +15,14 @@ limitations under the License. ==============================================================================*/ #include "mm_codec.h" +#include #include + extern "C" { #include #include #include +#include #include #include #include @@ -635,4 +638,382 @@ bool FFmpegAudioDecoder::decode(const std::string& raw_data, return true; } +// ---- MemoryMediaWriter (in-memory encoding base class) ---- + +namespace { + +struct MemWriteCtx { + std::vector* buf; + int64_t pos; +}; + +struct Writer { + static int32_t write(void* opaque, uint8_t* buf, int32_t buf_size) { + auto* mc = static_cast(opaque); + int64_t end_pos = mc->pos + buf_size; + if (end_pos > static_cast(mc->buf->size())) { + mc->buf->resize(static_cast(end_pos), 0); + } + std::memcpy(mc->buf->data() + mc->pos, buf, static_cast(buf_size)); + mc->pos = end_pos; + return buf_size; + } + + static int64_t seek(void* opaque, int64_t offset, int32_t whence) { + auto* mc = static_cast(opaque); + if (whence == AVSEEK_SIZE) { + return static_cast(mc->buf->size()); + } + int64_t pos = 0; + switch (whence) { + case 
SEEK_SET: + pos = offset; + break; + case SEEK_CUR: + pos = mc->pos + offset; + break; + case SEEK_END: + pos = static_cast(mc->buf->size()) + offset; + break; + default: + return AVERROR(EINVAL); + } + if (pos < 0) { + return AVERROR(EINVAL); + } + mc->pos = pos; + return pos; + } +}; + +} // namespace + +class MemoryMediaWriter { + public: + MemoryMediaWriter() = default; + + virtual ~MemoryMediaWriter() { + if (pkt_) { + av_packet_free(&pkt_); + } + if (codec_ctx_) { + avcodec_free_context(&codec_ctx_); + } + if (fmt_ctx_) { + if (!finished_) { + av_write_trailer(fmt_ctx_); + } + avformat_free_context(fmt_ctx_); + } + if (avio_ctx_) { + av_freep(&avio_ctx_->buffer); + avio_context_free(&avio_ctx_); + } + } + + protected: + bool init(const char* format, + AVCodecID codec_id, + int32_t width, + int32_t height, + double fps, + AVPixelFormat pix_fmt, + AVDictionary** opts = nullptr) { + const AVCodec* codec = avcodec_find_encoder(codec_id); + if (!codec) { + LOG(ERROR) << "MemoryMediaWriter: encoder not found, codec_id=" + << avcodec_get_name(codec_id); + return false; + } + + constexpr int32_t avio_buf_sz = 1 << 16; + uint8_t* avio_buf = + static_cast(av_malloc(static_cast(avio_buf_sz))); + if (!avio_buf) { + return false; + } + + avio_ctx_ = avio_alloc_context(avio_buf, + avio_buf_sz, + 1, + &write_ctx_, + nullptr, + &Writer::write, + &Writer::seek); + if (!avio_ctx_) { + av_freep(reinterpret_cast(&avio_buf)); + return false; + } + avio_ctx_->seekable = AVIO_SEEKABLE_NORMAL; + + const AVOutputFormat* fmt = av_guess_format(format, nullptr, nullptr); + if (!fmt) { + LOG(ERROR) << "MemoryMediaWriter: no muxer for " << format; + return false; + } + + if (avformat_alloc_output_context2(&fmt_ctx_, fmt, nullptr, nullptr) < 0 || + !fmt_ctx_) { + return false; + } + fmt_ctx_->pb = avio_ctx_; + fmt_ctx_->flags |= AVFMT_FLAG_CUSTOM_IO; + + codec_ctx_ = avcodec_alloc_context3(codec); + if (!codec_ctx_) { + return false; + } + + codec_ctx_->width = width; + codec_ctx_->height 
= height; + codec_ctx_->time_base = {1, static_cast(std::llround(fps))}; + codec_ctx_->framerate = {static_cast(std::llround(fps)), 1}; + codec_ctx_->pix_fmt = pix_fmt; + + if (fmt_ctx_->oformat->flags & AVFMT_GLOBALHEADER) { + codec_ctx_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER; + } + + if (avcodec_open2(codec_ctx_, codec, opts) < 0) { + LOG(ERROR) << "MemoryMediaWriter: avcodec_open2 failed for " + << codec->name; + return false; + } + + stream_ = avformat_new_stream(fmt_ctx_, nullptr); + if (!stream_) { + return false; + } + stream_->time_base = codec_ctx_->time_base; + if (avcodec_parameters_from_context(stream_->codecpar, codec_ctx_) < 0) { + return false; + } + + if (avformat_write_header(fmt_ctx_, nullptr) < 0) { + LOG(ERROR) << "MemoryMediaWriter: avformat_write_header failed"; + return false; + } + + pkt_ = av_packet_alloc(); + if (!pkt_) { + return false; + } + + LOG(INFO) << "MemoryMediaWriter: initialized " << codec->name << " [" + << format << "] " << width << "x" << height << " @ " << fps + << " fps"; + return true; + } + + bool send_frame(AVFrame* frame) { + if (avcodec_send_frame(codec_ctx_, frame) < 0) { + return false; + } + return drain_packets(); + } + + bool finish() { + avcodec_send_frame(codec_ctx_, nullptr); + if (!drain_packets()) { + return false; + } + av_write_trailer(fmt_ctx_); + finished_ = true; + return true; + } + + std::vector take_output() { return std::move(out_buf_); } + + AVCodecContext* codec_ctx() { return codec_ctx_; } + AVStream* stream() { return stream_; } + + bool drain_packets() { + while (avcodec_receive_packet(codec_ctx_, pkt_) == 0) { + av_packet_rescale_ts(pkt_, codec_ctx_->time_base, stream_->time_base); + pkt_->stream_index = stream_->index; + if (av_interleaved_write_frame(fmt_ctx_, pkt_) < 0) { + av_packet_unref(pkt_); + return false; + } + av_packet_unref(pkt_); + } + return true; + } + + AVFormatContext* fmt_ctx_ = nullptr; + AVIOContext* avio_ctx_ = nullptr; + AVCodecContext* codec_ctx_ = nullptr; + AVPacket* 
pkt_ = nullptr; + AVStream* stream_ = nullptr; + MemWriteCtx write_ctx_{&out_buf_, 0}; + std::vector out_buf_; + bool finished_ = false; +}; + +class MemoryVideoWriter final : public MemoryMediaWriter { + public: + MemoryVideoWriter() = default; + + ~MemoryVideoWriter() { + if (sws_ctx_) { + sws_freeContext(sws_ctx_); + } + if (yuv_frame_) { + av_frame_free(&yuv_frame_); + } + } + + bool write(const torch::Tensor& video, + double fps, + const std::string& format, + std::string& raw_data) { + if (video.dim() != 4 || video.size(1) != 3) { + LOG(ERROR) << "MemoryVideoWriter: expects [T,C,H,W] with C=3, got " + << video.sizes(); + return false; + } + if (video.scalar_type() != torch::kFloat32 || !video.device().is_cpu()) { + LOG(ERROR) << "MemoryVideoWriter: expects cpu float32 tensor"; + return false; + } + + const int64_t T = video.size(0); + const int64_t H = video.size(2); + const int64_t W = video.size(3); + if (T == 0 || H == 0 || W == 0) { + LOG(ERROR) << "MemoryVideoWriter: empty dimensions T=" << T << " H=" << H + << " W=" << W; + return false; + } + + AVCodecID codec_id; + AVPixelFormat pix_fmt; + AVDictionary* opts = nullptr; + + if (format == "avi") { + codec_id = AV_CODEC_ID_MJPEG; + pix_fmt = AV_PIX_FMT_YUVJ420P; + } else { + const AVCodec* x264_codec = avcodec_find_encoder_by_name("libx264"); + + if (x264_codec) { + codec_id = x264_codec->id; + pix_fmt = AV_PIX_FMT_YUV420P; + av_dict_set(&opts, "crf", "18", 0); + av_dict_set(&opts, "preset", "medium", 0); + av_dict_set(&opts, "profile", "high", 0); + av_dict_set(&opts, "level", "4.1", 0); + LOG(INFO) << "Using libx264 H.264 encoder with CRF=18"; + } else { + codec_id = AV_CODEC_ID_MPEG4; + pix_fmt = AV_PIX_FMT_YUV420P; + av_dict_set(&opts, "mbd", "2", 0); + LOG(WARNING) << "libx264 not available, using MPEG4 fallback"; + } + } + + if (!init(format.c_str(), + codec_id, + static_cast(W), + static_cast(H), + fps, + pix_fmt, + &opts)) { + if (opts) av_dict_free(&opts); + return false; + } + if (opts) 
av_dict_free(&opts); + + sws_ctx_ = sws_getContext(static_cast(W), + static_cast(H), + AV_PIX_FMT_RGB24, + static_cast(W), + static_cast(H), + pix_fmt, + SWS_BILINEAR, + nullptr, + nullptr, + nullptr); + if (!sws_ctx_) { + LOG(ERROR) << "MemoryVideoWriter: sws_getContext failed"; + return false; + } + + yuv_frame_ = av_frame_alloc(); + if (!yuv_frame_) { + return false; + } + yuv_frame_->format = pix_fmt; + yuv_frame_->width = static_cast(W); + yuv_frame_->height = static_cast(H); + if (av_frame_get_buffer(yuv_frame_, 0) < 0) { + return false; + } + + auto video_acc = video.accessor(); + const int64_t stride = W * 3; + std::vector rgb_buf(static_cast(H * stride)); + int64_t pts = 0; + + for (int64_t t = 0; t < T; ++t) { + for (int64_t y = 0; y < H; ++y) { + for (int64_t x = 0; x < W; ++x) { + rgb_buf[static_cast(y * stride + x * 3 + 0)] = + static_cast( + std::clamp(video_acc[t][0][y][x] * 255.0f, 0.0f, 255.0f)); + rgb_buf[static_cast(y * stride + x * 3 + 1)] = + static_cast( + std::clamp(video_acc[t][1][y][x] * 255.0f, 0.0f, 255.0f)); + rgb_buf[static_cast(y * stride + x * 3 + 2)] = + static_cast( + std::clamp(video_acc[t][2][y][x] * 255.0f, 0.0f, 255.0f)); + } + } + + const uint8_t* src_data[1] = {rgb_buf.data()}; + int32_t src_linesize[1] = {static_cast(stride)}; + + if (av_frame_make_writable(yuv_frame_) < 0) { + return false; + } + sws_scale(sws_ctx_, + src_data, + src_linesize, + 0, + static_cast(H), + yuv_frame_->data, + yuv_frame_->linesize); + yuv_frame_->pts = pts++; + + if (!send_frame(yuv_frame_)) { + return false; + } + } + + if (!finish()) { + return false; + } + + auto out = take_output(); + raw_data.assign(out.begin(), out.end()); + + LOG(INFO) << "MemoryVideoWriter: encoded " << T << " frames (" << W << "x" + << H << ") at " << fps << " fps [" << format << "], output " + << out.size() << " bytes"; + return true; + } + + private: + SwsContext* sws_ctx_ = nullptr; + AVFrame* yuv_frame_ = nullptr; +}; + +bool FFmpegVideoEncoder::encode(const 
torch::Tensor& video, + double fps, + const std::string& format, + std::string& raw_data) { + MemoryVideoWriter writer; + return writer.write(video, fps, format, raw_data); +} + } // namespace xllm diff --git a/xllm/core/framework/request/mm_codec.h b/xllm/core/framework/request/mm_codec.h index accf442891..2f83a43b62 100644 --- a/xllm/core/framework/request/mm_codec.h +++ b/xllm/core/framework/request/mm_codec.h @@ -63,4 +63,17 @@ class FFmpegAudioDecoder { AudioMetadata& meta, int64_t target_sr = 16000); }; + +class FFmpegVideoEncoder final { + public: + FFmpegVideoEncoder() = default; + ~FFmpegVideoEncoder() = default; + + // Encode video tensor [T, C, H, W] (float32, RGB, 0-1 range) with explicit + // container format ("mp4", "avi", etc.). + bool encode(const torch::Tensor& video, + double fps, + const std::string& format, + std::string& raw_data); +}; } // namespace xllm diff --git a/xllm/core/runtime/dit_forward_params.h b/xllm/core/runtime/dit_forward_params.h index c8193b9044..3dc07aa6f9 100644 --- a/xllm/core/runtime/dit_forward_params.h +++ b/xllm/core/runtime/dit_forward_params.h @@ -29,7 +29,9 @@ namespace xllm { struct DiTForwardInput { bool valid() const { return prompts.size() > 0 || prompt_embeds.defined() || - pooled_prompt_embeds.defined() || images.defined(); + pooled_prompt_embeds.defined() || images.defined() || + last_images.defined() || image_embeds.defined() || + condition_images.defined(); } void save_with_prefix(std::string prefix) const { @@ -107,6 +109,13 @@ struct DiTForwardInput { os << "undefined" << std::endl; } + os << "condition_images: "; + if (condition_images.defined()) { + os << condition_images.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + os << "masked_image_latents: "; if (masked_image_latents.defined()) { os << masked_image_latents.sizes() << std::endl; @@ -149,6 +158,20 @@ struct DiTForwardInput { os << "undefined" << std::endl; } + os << "last_images: "; + if (last_images.defined()) { + os << 
last_images.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + + os << "image_embeds: "; + if (image_embeds.defined()) { + os << image_embeds.sizes() << std::endl; + } else { + os << "undefined" << std::endl; + } + // Print generation_params os << "\n--- Generation Parameters ---" << std::endl; os << "width: " << generation_params.width << std::endl; @@ -212,10 +235,19 @@ struct DiTForwardInput { input.control_image = control_image.to(device, dtype); } + if (condition_images.defined()) { + input.condition_images = condition_images.to(device, dtype); + } + + if (last_images.defined()) { + input.last_images = last_images.to(device, dtype); + } + if (image_embeds.defined()) { + input.image_embeds = image_embeds.to(device, dtype); + } if (prompt_audio.defined()) { input.prompt_audio = prompt_audio.to(device, torch::kFloat32); } - return input; } @@ -241,6 +273,8 @@ struct DiTForwardInput { torch::Tensor control_image; + torch::Tensor condition_images; + torch::Tensor masked_image_latents; torch::Tensor prompt_embeds; @@ -253,6 +287,12 @@ struct DiTForwardInput { torch::Tensor latents; + // Last images for video generation + torch::Tensor last_images; + + // Image embeddings for video generation + torch::Tensor image_embeds; + // Optional prompt audio for voice cloning (LongCat-AudioDiT) // Shape: (batch, 1, num_samples) at 24kHz torch::Tensor prompt_audio; diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 8ba557baf9..1abce239af 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -328,7 +328,14 @@ inline size_t get_dit_generation_params_size( 4 // true_cfg_scale, guidance_scale, strength, cfg_renorm_min + type_size // num_images_per_prompt + type_size // seed - + type_size; // enable_cfg_renorm + + type_size // enable_cfg_renorm + + type_size // num_frames + + type_size // force_video_output + + 
type_size // video_fps + + type_size * // guidance_scale_2, boundary_ratio, flow_shift + 3 + + type_size // seconds + + type_size; // num_videos_per_prompt } inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { @@ -351,6 +358,8 @@ inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { size += get_tensor_size(input.negative_prompt_embeds); size += get_tensor_size(input.negative_pooled_prompt_embeds); size += get_tensor_size(input.latents); + size += get_tensor_size(input.last_images); + size += get_tensor_size(input.image_embeds); // Generation params size += get_dit_generation_params_size(input.generation_params); @@ -975,6 +984,14 @@ inline void write_dit_generation_params(char*& buffer, write_data(buffer, params.strength); write_data(buffer, params.enable_cfg_renorm); write_data(buffer, params.cfg_renorm_min); + write_data(buffer, params.num_frames); + write_data(buffer, params.force_video_output); + write_data(buffer, params.video_fps); + write_data(buffer, params.guidance_scale_2); + write_data(buffer, params.seconds); + write_data(buffer, params.boundary_ratio); + write_data(buffer, params.flow_shift); + write_data(buffer, params.num_videos_per_prompt); } inline void write_dit_generation_params(RawInputSerializeContext& context, @@ -990,6 +1007,14 @@ inline void write_dit_generation_params(RawInputSerializeContext& context, write_data(context.descriptor, params.strength); write_data(context.descriptor, params.enable_cfg_renorm); write_data(context.descriptor, params.cfg_renorm_min); + write_data(context.descriptor, params.num_frames); + write_data(context.descriptor, params.force_video_output); + write_data(context.descriptor, params.video_fps); + write_data(context.descriptor, params.guidance_scale_2); + write_data(context.descriptor, params.seconds); + write_data(context.descriptor, params.boundary_ratio); + write_data(context.descriptor, params.flow_shift); + write_data(context.descriptor, params.num_videos_per_prompt); 
} inline void write_dit_forward_input(char*& buffer, @@ -1011,6 +1036,8 @@ inline void write_dit_forward_input(char*& buffer, write_tensor(buffer, input.negative_prompt_embeds); write_tensor(buffer, input.negative_pooled_prompt_embeds); write_tensor(buffer, input.latents); + write_tensor(buffer, input.last_images); + write_tensor(buffer, input.image_embeds); write_dit_generation_params(buffer, input.generation_params); } @@ -1847,6 +1874,14 @@ inline void read_dit_generation_params(const char*& buffer, read_data(buffer, params.strength); read_data(buffer, params.enable_cfg_renorm); read_data(buffer, params.cfg_renorm_min); + read_data(buffer, params.num_frames); + read_data(buffer, params.force_video_output); + read_data(buffer, params.video_fps); + read_data(buffer, params.guidance_scale_2); + read_data(buffer, params.seconds); + read_data(buffer, params.boundary_ratio); + read_data(buffer, params.flow_shift); + read_data(buffer, params.num_videos_per_prompt); } inline void read_dit_generation_params(ReadContext& context, @@ -1862,6 +1897,14 @@ inline void read_dit_generation_params(ReadContext& context, read_data(context, params.strength); read_data(context, params.enable_cfg_renorm); read_data(context, params.cfg_renorm_min); + read_data(context, params.num_frames); + read_data(context, params.force_video_output); + read_data(context, params.video_fps); + read_data(context, params.guidance_scale_2); + read_data(context, params.seconds); + read_data(context, params.boundary_ratio); + read_data(context, params.flow_shift); + read_data(context, params.num_videos_per_prompt); } inline void read_dit_forward_input(const char*& buffer, @@ -1883,6 +1926,8 @@ inline void read_dit_forward_input(const char*& buffer, read_tensor(buffer, input.negative_prompt_embeds); read_tensor(buffer, input.negative_pooled_prompt_embeds); read_tensor(buffer, input.latents); + read_tensor(buffer, input.last_images); + read_tensor(buffer, input.image_embeds); 
read_dit_generation_params(buffer, input.generation_params); } diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index a872143d54..5d1cdb579a 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -370,6 +370,10 @@ bool dit_forward_input_to_proto(const DiTForwardInput& dit_inputs, torch_tensor_to_proto_tensor(dit_inputs.latents, pb_dit_inputs->mutable_latents()); + torch_tensor_to_proto_tensor(dit_inputs.last_images, + pb_dit_inputs->mutable_last_images()); + torch_tensor_to_proto_tensor(dit_inputs.image_embeds, + pb_dit_inputs->mutable_image_embeds()); torch_tensor_to_proto_tensor(dit_inputs.prompt_audio, pb_dit_inputs->mutable_prompt_audio()); @@ -408,6 +412,18 @@ bool generation_params_to_proto( dit_generation_params.enable_cfg_renorm); pb_dit_generation_params->set_cfg_renorm_min( dit_generation_params.cfg_renorm_min); + pb_dit_generation_params->set_num_frames(dit_generation_params.num_frames); + pb_dit_generation_params->set_force_video_output( + dit_generation_params.force_video_output); + pb_dit_generation_params->set_video_fps(dit_generation_params.video_fps); + pb_dit_generation_params->set_guidance_scale_2( + dit_generation_params.guidance_scale_2); + pb_dit_generation_params->set_seconds(dit_generation_params.seconds); + pb_dit_generation_params->set_boundary_ratio( + dit_generation_params.boundary_ratio); + pb_dit_generation_params->set_flow_shift(dit_generation_params.flow_shift); + pb_dit_generation_params->set_num_videos_per_prompt( + dit_generation_params.num_videos_per_prompt); return true; } @@ -482,6 +498,13 @@ bool proto_to_dit_forward_input(const proto::DiTForwardInput& pb_dit_inputs, if (pb_dit_inputs.has_latents()) { dit_inputs.latents = util::proto_to_torch(pb_dit_inputs.latents()); } + if (pb_dit_inputs.has_last_images()) { + dit_inputs.last_images = util::proto_to_torch(pb_dit_inputs.last_images()); + } + if (pb_dit_inputs.has_image_embeds()) { + 
dit_inputs.image_embeds = + util::proto_to_torch(pb_dit_inputs.image_embeds()); + } if (!proto_to_generation_params(pb_dit_inputs.generation_params(), dit_inputs.generation_params)) { @@ -522,6 +545,18 @@ bool proto_to_generation_params( pb_dit_generation_params.enable_cfg_renorm(); dit_generation_params.cfg_renorm_min = pb_dit_generation_params.cfg_renorm_min(); + dit_generation_params.num_frames = pb_dit_generation_params.num_frames(); + dit_generation_params.force_video_output = + pb_dit_generation_params.force_video_output(); + dit_generation_params.video_fps = pb_dit_generation_params.video_fps(); + dit_generation_params.guidance_scale_2 = + pb_dit_generation_params.guidance_scale_2(); + dit_generation_params.seconds = pb_dit_generation_params.seconds(); + dit_generation_params.boundary_ratio = + pb_dit_generation_params.boundary_ratio(); + dit_generation_params.flow_shift = pb_dit_generation_params.flow_shift(); + dit_generation_params.num_videos_per_prompt = + pb_dit_generation_params.num_videos_per_prompt(); return true; } diff --git a/xllm/proto/CMakeLists.txt b/xllm/proto/CMakeLists.txt index a69ebcf1d8..bb95ddf3b1 100644 --- a/xllm/proto/CMakeLists.txt +++ b/xllm/proto/CMakeLists.txt @@ -20,6 +20,7 @@ proto_library( xservice.proto image_generation.proto audio_generation.proto + video_generation.proto mooncake_transfer_engine.proto embedding_data.proto anthropic.proto diff --git a/xllm/proto/image_generation.proto b/xllm/proto/image_generation.proto index 18ef24e219..bfdb40947a 100644 --- a/xllm/proto/image_generation.proto +++ b/xllm/proto/image_generation.proto @@ -49,6 +49,9 @@ message Input { // multiple image repeated string images = 14; + + // Condition Image + optional string condition_image = 15; } // Generation parameters container diff --git a/xllm/proto/video_generation.proto b/xllm/proto/video_generation.proto new file mode 100644 index 0000000000..66d2bab996 --- /dev/null +++ b/xllm/proto/video_generation.proto @@ -0,0 +1,133 @@ +syntax =
"proto3"; + +option go_package = "jd.com/jd-infer/xllm;xllm"; +package xllm.proto; + +import "common.proto"; +import "tensor.proto"; + +// Input parameters container +message VideoInput { + // Primary input text description for video generation + string prompt = 1; + + // Negative prompt to exclude unwanted features + optional string negative_prompt = 2; + + // prompt embedding + optional Tensor prompt_embed = 3; + + // negative prompt embedding + optional Tensor negative_prompt_embed = 4; + + // Input image for image-to-video generation (base64 encoded) + optional string image = 5; + + // Last image (base64 encoded) + optional string last_image = 6; + + // Image embeddings for image-to-video generation + optional Tensor image_embeds = 7; +} + +// Generation parameters container +message VideoParameters { + // Size of the generated video frames, e.g. "256*256" or "832*480" + optional string size = 1; + + // Number of inference steps for video generation + optional int32 num_inference_steps = 2; + + // Guidance scale value for prompt adherence + optional float guidance_scale = 3; + + // Guidance scale value for negative prompt adherence + optional float guidance_scale_2 = 4; + + // Number of videos to generate per prompt + optional int32 num_videos_per_prompt = 5; + + // Random seed value for generation + optional int64 seed = 6; + + // Maximum sequence length for prompt processing + optional int32 max_sequence_length = 7; + + // Number of frames to generate + optional int32 num_frames = 8; + + // Output video FPS (frames per second) + optional double fps = 9; + + // video duration + optional int32 seconds = 10; + + // high-low noise separation ratio + optional float boundary_ratio = 11; + + // scheduler flow shift + optional float flow_shift = 12; + + // True CFG scale value for balancing generation + optional float true_cfg_scale = 13; +} + +// Request structure for video generation +message VideoGenerationRequest { + // Model ID to use for generation + string
model = 1; + + VideoInput input = 2; + + VideoParameters parameters = 3; + + // End-user identifier + optional string user = 4; + + // Server-side request ID for tracking + optional string request_id = 5; +} + +// Individual video generation result +message VideoGenData { + // Base64-encoded video data (MP4 format) + optional string video = 1; + + // Width of the generated video in pixels + int32 width = 3; + + // Height of the generated video in pixels + int32 height = 4; + + // Seed used for generation + int64 seed = 5; + + // Number of frames + int32 num_frames = 6; + + // FPS of the output video + double fps = 7; +} + +// Output container for video generation results +message VideoGenerationOutput { + repeated VideoGenData results = 1; +} + +// Response structure for video generation requests +message VideoGenerationResponse { + // Response ID + string id = 1; + + // Object type + string object = 2; + + // Unix timestamp of when the response was created + int64 created = 3; + + // Model used for generation + string model = 4; + + // Contains generation results + VideoGenerationOutput output = 5; +} diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 506f03339d..1393d21679 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -242,13 +242,18 @@ message DiTForwardInput { Tensor negative_prompt_embeds = 13; Tensor negative_pooled_prompt_embeds = 14; Tensor latents = 15; - - // generation params - DiTGenerationParams generation_params = 16; + Tensor last_images = 17; + Tensor image_embeds = 18; // Voice cloning fields (LongCat-AudioDiT) - optional Tensor prompt_audio = 17; - optional string audio_prompt_text = 18; + optional Tensor prompt_audio = 19; + optional string audio_prompt_text = 20; + + // Conditional image generation + Tensor condition_images = 21; + + // generation params + DiTGenerationParams generation_params = 16; } message DiTForwardOutput { @@ -267,6 +272,14 @@ message DiTGenerationParams { float strength = 9; bool 
enable_cfg_renorm = 10; float cfg_renorm_min = 11; + int32 num_frames = 12; + bool force_video_output = 13; + double video_fps = 14; + float guidance_scale_2 = 15; + int32 seconds = 16; + float boundary_ratio = 17; + float flow_shift = 18; + uint32 num_videos_per_prompt = 19; } message PackedForwardInput { diff --git a/xllm/proto/xllm_service.proto b/xllm/proto/xllm_service.proto index e2c6ce7c8e..255babd812 100644 --- a/xllm/proto/xllm_service.proto +++ b/xllm/proto/xllm_service.proto @@ -11,6 +11,7 @@ import "chat.proto"; import "embedding.proto"; import "image_generation.proto"; import "audio_generation.proto"; +import "video_generation.proto"; import "rerank.proto"; import "models.proto"; import "anthropic.proto"; @@ -68,6 +69,9 @@ service XllmAPIService { rpc AudioGeneration(AudioGenerationRequest) returns (AudioGenerationResponse); rpc AudioGenerationHttp(HttpRequest) returns (HttpResponse); + + rpc VideoGeneration(VideoGenerationRequest) returns (VideoGenerationResponse); + rpc VideoGenerationHttp(HttpRequest) returns (HttpResponse); rpc Rerank (RerankRequest) returns (RerankResponse); rpc RerankHttp (HttpRequest) returns (HttpResponse); diff --git a/xllm/server/xllm_server.cpp b/xllm/server/xllm_server.cpp index 1e6c7a38c1..5fdc2a95e7 100644 --- a/xllm/server/xllm_server.cpp +++ b/xllm/server/xllm_server.cpp @@ -44,6 +44,7 @@ constexpr const char* kApiServiceRoutes = "v1/models => ModelsHttp," "v1/image/generation => ImageGenerationHttp," "v1/audio/generation => AudioGenerationHttp," + "v1/video/generation => VideoGenerationHttp," "v1/rerank => RerankHttp," "v1/messages => AnthropicMessagesHttp," "v2/repository/index => ModelVersionsHttp,"