diff --git a/xllm/api_service/image_generation_service_impl.cpp b/xllm/api_service/image_generation_service_impl.cpp index 515b1804d..eda1358c9 100644 --- a/xllm/api_service/image_generation_service_impl.cpp +++ b/xllm/api_service/image_generation_service_impl.cpp @@ -79,6 +79,14 @@ void ImageGenerationServiceImpl::process_async_impl( return; } + // Check if the request is being rate-limited. + if (master_->get_rate_limiter()->is_limited()) { + call->finish_with_error( + StatusCode::RESOURCE_EXHAUSTED, + "The number of concurrent requests has reached the limit."); + return; + } + // create DiTRequestParams for image generation request DiTRequestParams request_params( rpc_request, call->get_x_request_id(), call->get_x_request_time()); @@ -90,9 +98,11 @@ void ImageGenerationServiceImpl::process_async_impl( call.get(), [call, model, + master = master_, request_id = std::move(saved_request_id), created_time = absl::ToUnixSeconds(absl::Now())]( const DiTRequestOutput& req_output) -> bool { + master->get_rate_limiter()->decrease_one_request(); if (req_output.status.has_value()) { const auto& status = req_output.status.value(); if (!status.ok()) { diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 0a2e23e47..322f8d8ba 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -331,6 +331,8 @@ DECLARE_int64(dit_cache_end_blocks); DECLARE_int64(dit_sp_communication_overlap); +DECLARE_int64(dit_generation_image_area_max); + DECLARE_bool(dit_debug_print); DECLARE_bool(use_audio_in_video); diff --git a/xllm/core/framework/batch/dit_batch.cpp b/xllm/core/framework/batch/dit_batch.cpp index abd4e8d6d..67a373fcf 100644 --- a/xllm/core/framework/batch/dit_batch.cpp +++ b/xllm/core/framework/batch/dit_batch.cpp @@ -61,11 +61,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() { std::vector negative_pooled_prompt_embeds; std::vector images; - std::vector condition_images; std::vector mask_images; std::vector 
control_images; std::vector latents; std::vector masked_image_latents; + const auto batch_size = request_vec_.size(); prompt_embeds.reserve(batch_size); pooled_prompt_embeds.reserve(batch_size); @@ -76,6 +76,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() { control_images.reserve(batch_size); latents.reserve(batch_size); masked_image_latents.reserve(batch_size); + + std::vector images_list; + size_t images_size = request_vec_.empty() ? 0 : request_vec_[0]->state().input_params().images.size(); + bool images_size_valid = images_size > 0; + for (const auto& request : request_vec_) { const auto& generation_params = request->state().generation_params(); if (input.generation_params != generation_params) { @@ -107,9 +112,12 @@ DiTForwardInput DiTBatch::prepare_forward_input() { images.emplace_back(input_params.image); mask_images.emplace_back(input_params.mask_image); - condition_images.emplace_back(input_params.condition_image); control_images.emplace_back(input_params.control_image); + if (input_params.images.size() != images_size) { + images_size_valid = false; + } + // Voice cloning: prompt_audio is per-request (batch_size==1 in practice). // Forward the first defined tensor; multi-batch voice cloning is not // supported (different prompt lengths can't be stacked). 
@@ -142,8 +150,26 @@ DiTForwardInput DiTBatch::prepare_forward_input() { input.images = torch::stack(images); } - if (check_tensors_valid(condition_images)) { - input.condition_images = torch::stack(condition_images); + if (images_size_valid) { + images_list.reserve(images_size); + std::vector vec; + vec.reserve(request_vec_.size()); + + bool all_valid = true; + for (size_t idx = 0; idx < images_size; ++idx) { + vec.clear(); + for (const auto& req : request_vec_) { + vec.emplace_back(req->state().input_params().images[idx]); + } + if (!check_tensors_valid(vec)) { + all_valid = false; + break; + } + images_list.emplace_back(torch::stack(vec)); + } + if (all_valid) { + input.images_list = std::move(images_list); + } } if (check_tensors_valid(mask_images)) { diff --git a/xllm/core/framework/config/dit_config.cpp b/xllm/core/framework/config/dit_config.cpp index c071b28aa..3f3321edf 100644 --- a/xllm/core/framework/config/dit_config.cpp +++ b/xllm/core/framework/config/dit_config.cpp @@ -61,6 +61,11 @@ DEFINE_bool(dit_debug_print, false, "whether print the debug info for dit models"); +DEFINE_int64(dit_generation_image_area_max, + 0, + "Maximum allowed image area (width * height) for image generation " + "requests. 
If set to 0, there is no limit."); + namespace xllm { void DiTConfig::from_flags() { @@ -76,7 +81,8 @@ void DiTConfig::from_flags() { .dit_cache_start_blocks(FLAGS_dit_cache_start_blocks) .dit_cache_end_blocks(FLAGS_dit_cache_end_blocks) .dit_sp_communication_overlap(FLAGS_dit_sp_communication_overlap) - .dit_debug_print(FLAGS_dit_debug_print); + .dit_debug_print(FLAGS_dit_debug_print) + .dit_generation_image_area_max(FLAGS_dit_generation_image_area_max); } void DiTConfig::from_json(const JsonReader& json) { @@ -104,7 +110,9 @@ void DiTConfig::from_json(const JsonReader& json) { .dit_sp_communication_overlap(json.value_or( "dit_sp_communication_overlap", dit_sp_communication_overlap())) .dit_debug_print( - json.value_or("dit_debug_print", dit_debug_print())); + json.value_or("dit_debug_print", dit_debug_print())) + .dit_generation_image_area_max(json.value_or( + "dit_generation_image_area_max", dit_generation_image_area_max())); } DiTConfig& DiTConfig::get_instance() { diff --git a/xllm/core/framework/config/dit_config.h b/xllm/core/framework/config/dit_config.h index af30f5d81..f495e73ec 100644 --- a/xllm/core/framework/config/dit_config.h +++ b/xllm/core/framework/config/dit_config.h @@ -50,7 +50,8 @@ class DiTConfig final { "dit_cache_start_blocks", "dit_cache_end_blocks", "dit_sp_communication_overlap", - "dit_debug_print"}}; + "dit_debug_print", + "dit_generation_image_area_max"}}; return kOptionCategory; } @@ -77,6 +78,8 @@ class DiTConfig final { PROPERTY(int64_t, dit_sp_communication_overlap) = 1; PROPERTY(bool, dit_debug_print) = false; + + PROPERTY(int64_t, dit_generation_image_area_max) = 0; }; } // namespace xllm diff --git a/xllm/core/framework/request/dit_request_params.cpp b/xllm/core/framework/request/dit_request_params.cpp index 1c7f32adb..73c7b0fa2 100644 --- a/xllm/core/framework/request/dit_request_params.cpp +++ b/xllm/core/framework/request/dit_request_params.cpp @@ -19,6 +19,7 @@ limitations under the License. 
#include "butil/base64.h" #include "core/common/instance_name.h" #include "core/common/macros.h" +#include "core/framework/config/dit_config.h" #include "core/util/utils.h" #include "core/util/uuid.h" #include "mm_codec.h" @@ -139,14 +140,19 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request, } } - if (input.has_condition_image()) { - std::string raw_bytes; - if (!butil::Base64Decode(input.condition_image(), &raw_bytes)) { + input_params.images.reserve(input.images().size()); + for (const auto& image : input.images()) { + std::string binary; + if (!butil::Base64Decode(image, &binary)) { LOG(ERROR) << "Base64 image decode failed"; + continue; } - if (!decoder.decode(raw_bytes, input_params.condition_image)) { + torch::Tensor tensor; + if (!decoder.decode(binary, tensor)) { LOG(ERROR) << "Image decode failed."; + continue; } + input_params.images.emplace_back(std::move(tensor)); } if (input.has_mask_image()) { @@ -258,11 +264,35 @@ DiTRequestParams::DiTRequestParams(const proto::AudioGenerationRequest& request, bool DiTRequestParams::verify_params( std::function callback) const { - if (input_params.prompt.empty()) { + if (input_params.prompt.empty() && !input_params.prompt_embed.defined()) { CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "prompt is empty"); return false; } + if (generation_params.width < 0 || generation_params.height < 0) { + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "Invalid image dimensions: width and height must be non-negative."); + return false; + } + + // Check if the image area exceeds the maximum allowed area. 
+ if (::xllm::DiTConfig::get_instance().dit_generation_image_area_max() > 0) { + int64_t area = static_cast(generation_params.width) * + static_cast(generation_params.height); + if (area > + ::xllm::DiTConfig::get_instance().dit_generation_image_area_max()) { + CALLBACK_WITH_ERROR( + StatusCode::INVALID_ARGUMENT, + "Requested image area (" + std::to_string(area) + + ") exceeds the maximum allowed area (" + + std::to_string(::xllm::DiTConfig::get_instance() + .dit_generation_image_area_max()) + + ")."); + return false; + } + } + return true; } diff --git a/xllm/core/framework/request/dit_request_state.h b/xllm/core/framework/request/dit_request_state.h index 3906d0ae2..179660565 100644 --- a/xllm/core/framework/request/dit_request_state.h +++ b/xllm/core/framework/request/dit_request_state.h @@ -115,7 +115,7 @@ struct DiTInputParams { torch::Tensor image; - torch::Tensor condition_image; + std::vector images; torch::Tensor control_image; diff --git a/xllm/core/framework/request/request.h b/xllm/core/framework/request/request.h index 44a7f3955..6dbe73e0c 100644 --- a/xllm/core/framework/request/request.h +++ b/xllm/core/framework/request/request.h @@ -57,11 +57,6 @@ class Request : public RequestBase { bool cancelled() const { return cancelled_.load(std::memory_order_relaxed); } - // Get the elapsed time since the request was created. - double elapsed_seconds() const { - return absl::ToDoubleSeconds(absl::Now() - created_time_); - } - RequestOutput generate_output(const Tokenizer& tokenizer, ThreadPool* thread_pool = nullptr); diff --git a/xllm/core/framework/request/request_base.h b/xllm/core/framework/request/request_base.h index 2ee32bae3..31e18bee8 100644 --- a/xllm/core/framework/request/request_base.h +++ b/xllm/core/framework/request/request_base.h @@ -47,6 +47,11 @@ class RequestBase { absl::Time created_time() const { return created_time_; } + // Get the elapsed time since the request was created. 
+ double elapsed_seconds() const { + return absl::ToDoubleSeconds(absl::Now() - created_time_); + } + const std::string& request_id() const { return request_id_; } const std::string& service_request_id() const { return service_request_id_; } diff --git a/xllm/core/runtime/dit_forward_params.h b/xllm/core/runtime/dit_forward_params.h index 594dbbb2e..c8193b904 100644 --- a/xllm/core/runtime/dit_forward_params.h +++ b/xllm/core/runtime/dit_forward_params.h @@ -82,12 +82,16 @@ struct DiTForwardInput { os << "undefined" << std::endl; } - os << "condition_images: "; - if (condition_images.defined()) { - os << condition_images.sizes() << std::endl; - } else { - os << "undefined" << std::endl; + os << "images_list: ["; + for (size_t i = 0; i < images_list.size(); ++i) { + if (images_list[i].defined()) { + os << images_list[i].sizes(); + } else { + os << "undefined"; + } + if (i < images_list.size() - 1) os << ", "; } + os << "]" << std::endl; os << "mask_images: "; if (mask_images.defined()) { @@ -200,8 +204,11 @@ struct DiTForwardInput { input.mask_images = mask_images.to(device, dtype); } - if (condition_images.defined()) { - input.condition_images = condition_images.to(device, dtype); + decltype(images_list) converted_images; + converted_images.reserve(images_list.size()); + for (const auto& img : images_list) { + converted_images.emplace_back(img.defined() ? img.to(device, dtype) : img); } + input.images_list.swap(converted_images); if (control_image.defined()) { @@ -231,7 +238,7 @@ struct DiTForwardInput { torch::Tensor images; - torch::Tensor condition_images; + std::vector images_list; torch::Tensor mask_images; diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index b7a78fd48..8ba557baf 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -143,6 +143,15 @@ inline size_t get_tensor_size(const torch::Tensor& tensor) { return size; } +inline size_t get_vector_tensor_size( + const std::vector& tensor_vec) { + size_t size = type_size; // tensor_num + for (const auto& tensor : tensor_vec) { + size += 
get_tensor_size(tensor); + } + return size; +} + template inline size_t get_2d_vector_size(const std::vector>& vec2d) { size_t size = type_size; @@ -333,7 +342,7 @@ inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { // Tensors size += get_tensor_size(input.images); - size += get_tensor_size(input.condition_images); + size += get_vector_tensor_size(input.images_list); size += get_tensor_size(input.mask_images); size += get_tensor_size(input.control_image); size += get_tensor_size(input.masked_image_latents); @@ -350,11 +359,7 @@ inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { } inline size_t get_dit_forward_output_size(const DiTForwardOutput& output) { - size_t size = type_size; // vector size - for (const auto& tensor : output.tensors) { - size += get_tensor_size(tensor); - } - return size; + return get_vector_tensor_size(output.tensors); } template @@ -806,6 +811,15 @@ inline void write_vector_tensor(char*& buffer, } } +inline void write_vector_tensor(RawInputSerializeContext& context, + const std::vector& tensor_vec) { + int32_t tensor_num = tensor_vec.size(); + write_data(context.descriptor, tensor_num); + for (const auto& tensor : tensor_vec) { + write_tensor(context, tensor); + } +} + inline void write_mm_dict(char*& buffer, const MMDict& mm_dict) { // size size_t size = mm_dict.size(); @@ -988,7 +1002,7 @@ inline void write_dit_forward_input(char*& buffer, write_string_vector(buffer, input.negative_prompts_2); write_tensor(buffer, input.images); - write_tensor(buffer, input.condition_images); + write_vector_tensor(buffer, input.images_list); write_tensor(buffer, input.mask_images); write_tensor(buffer, input.control_image); write_tensor(buffer, input.masked_image_latents); @@ -1011,7 +1025,7 @@ inline void write_dit_forward_input(RawInputSerializeContext& context, write_string_vector(context.descriptor, input.negative_prompts_2); write_tensor(context, input.images); - write_tensor(context, 
input.condition_images); + write_vector_tensor(context, input.images_list); write_tensor(context, input.mask_images); write_tensor(context, input.control_image); write_tensor(context, input.masked_image_latents); @@ -1026,10 +1040,7 @@ inline void write_dit_forward_input(RawInputSerializeContext& context, inline void write_dit_forward_output(char*& buffer, const DiTForwardOutput& output) { - write_data(buffer, static_cast(output.tensors.size())); - for (const auto& tensor : output.tensors) { - write_tensor(buffer, tensor); - } + write_vector_tensor(buffer, output.tensors); } inline void safe_advance_buffer(const char*& buffer, size_t offset) { @@ -1623,6 +1634,18 @@ inline void read_vector_tensor(const char*& buffer, } } +inline void read_vector_tensor(ReadContext& context, + std::vector& tensor_vec, + Stream* stream = nullptr, + bool force_host_materialize = false) { + int32_t tensor_num; + read_data(context, tensor_num); + tensor_vec.resize(tensor_num); + for (size_t i = 0; i < tensor_num; ++i) { + read_tensor(context, tensor_vec[i], stream, force_host_materialize); + } +} + inline void read_mm_dict(const char*& buffer, MMDict& mm_dict, const char*& device_buffer) { @@ -1851,7 +1874,7 @@ inline void read_dit_forward_input(const char*& buffer, read_string_vector(buffer, input.negative_prompts_2); read_tensor(buffer, input.images); - read_tensor(buffer, input.condition_images); + read_vector_tensor(buffer, input.images_list); read_tensor(buffer, input.mask_images); read_tensor(buffer, input.control_image); read_tensor(buffer, input.masked_image_latents); @@ -1877,10 +1900,10 @@ inline void read_dit_forward_input(ReadContext& context, input.images, /*stream=*/nullptr, /*force_host_materialize=*/true); - read_tensor(context, - input.condition_images, - /*stream=*/nullptr, - /*force_host_materialize=*/true); + read_vector_tensor(context, + input.images_list, + /*stream=*/nullptr, + /*force_host_materialize=*/true); read_tensor(context, input.mask_images, 
/*stream=*/nullptr, @@ -1919,12 +1942,7 @@ inline void read_dit_forward_input(ReadContext& context, inline void read_dit_forward_output(const char*& buffer, DiTForwardOutput& output) { - uint64_t size; - read_data(buffer, size); - output.tensors.resize(size); - for (auto& tensor : output.tensors) { - read_tensor(buffer, tensor); - } + read_vector_tensor(buffer, output.tensors); } inline void initialize_device_buffer_session(ReadContext& context, diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index c7807c053..a872143d5 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -337,47 +337,45 @@ bool dit_forward_input_to_proto(const DiTForwardInput& dit_inputs, ADD_VECTOR_TO_PROTO(pb_dit_inputs->mutable_negative_prompts_2(), dit_inputs.negative_prompts_2); - if (dit_inputs.images.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.images, - pb_dit_inputs->mutable_images()); - } - if (dit_inputs.condition_images.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.condition_images, - pb_dit_inputs->mutable_condition_images()); - } - if (dit_inputs.mask_images.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.mask_images, - pb_dit_inputs->mutable_mask_images()); - } - if (dit_inputs.control_image.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.control_image, - pb_dit_inputs->mutable_control_image()); - } - if (dit_inputs.masked_image_latents.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.masked_image_latents, - pb_dit_inputs->mutable_masked_image_latents()); - } - if (dit_inputs.prompt_embeds.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.prompt_embeds, - pb_dit_inputs->mutable_prompt_embeds()); - } - if (dit_inputs.pooled_prompt_embeds.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.pooled_prompt_embeds, - pb_dit_inputs->mutable_pooled_prompt_embeds()); - } - if (dit_inputs.negative_prompt_embeds.defined()) { - torch_tensor_to_proto_tensor( - 
dit_inputs.negative_prompt_embeds, - pb_dit_inputs->mutable_negative_prompt_embeds()); - } - if (dit_inputs.negative_pooled_prompt_embeds.defined()) { - torch_tensor_to_proto_tensor( - dit_inputs.negative_pooled_prompt_embeds, - pb_dit_inputs->mutable_negative_pooled_prompt_embeds()); + torch_tensor_to_proto_tensor(dit_inputs.images, + pb_dit_inputs->mutable_images()); + + auto* pb_images_list = + pb_dit_inputs->mutable_images_list()->mutable_tensors(); + for (const auto& tensor : dit_inputs.images_list) { + torch_tensor_to_proto_tensor(tensor, pb_images_list->Add()); } - if (dit_inputs.latents.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.latents, - pb_dit_inputs->mutable_latents()); + + torch_tensor_to_proto_tensor(dit_inputs.mask_images, + pb_dit_inputs->mutable_mask_images()); + + torch_tensor_to_proto_tensor(dit_inputs.control_image, + pb_dit_inputs->mutable_control_image()); + + torch_tensor_to_proto_tensor(dit_inputs.masked_image_latents, + pb_dit_inputs->mutable_masked_image_latents()); + + torch_tensor_to_proto_tensor(dit_inputs.prompt_embeds, + pb_dit_inputs->mutable_prompt_embeds()); + + torch_tensor_to_proto_tensor(dit_inputs.pooled_prompt_embeds, + pb_dit_inputs->mutable_pooled_prompt_embeds()); + + torch_tensor_to_proto_tensor(dit_inputs.negative_prompt_embeds, + pb_dit_inputs->mutable_negative_prompt_embeds()); + + torch_tensor_to_proto_tensor( + dit_inputs.negative_pooled_prompt_embeds, + pb_dit_inputs->mutable_negative_pooled_prompt_embeds()); + + torch_tensor_to_proto_tensor(dit_inputs.latents, + pb_dit_inputs->mutable_latents()); + + torch_tensor_to_proto_tensor(dit_inputs.prompt_audio, + pb_dit_inputs->mutable_prompt_audio()); + + if (!dit_inputs.audio_prompt_text.empty()) { + pb_dit_inputs->set_audio_prompt_text(dit_inputs.audio_prompt_text); } if (!generation_params_to_proto(dit_inputs.generation_params, @@ -386,14 +384,6 @@ bool dit_forward_input_to_proto(const DiTForwardInput& dit_inputs, return false; } - if 
(dit_inputs.prompt_audio.defined()) { - torch_tensor_to_proto_tensor(dit_inputs.prompt_audio, - pb_dit_inputs->mutable_prompt_audio()); - } - if (!dit_inputs.audio_prompt_text.empty()) { - pb_dit_inputs->set_audio_prompt_text(dit_inputs.audio_prompt_text); - } - return true; } @@ -447,9 +437,12 @@ bool proto_to_dit_forward_input(const proto::DiTForwardInput& pb_dit_inputs, dit_inputs.images = util::proto_to_torch(pb_dit_inputs.images()); } - if (pb_dit_inputs.has_condition_images()) { - dit_inputs.condition_images = - util::proto_to_torch(pb_dit_inputs.condition_images()); + if (pb_dit_inputs.has_images_list()) { + dit_inputs.images_list.reserve( + pb_dit_inputs.images_list().tensors().size()); + for (const auto& pb_tensor : pb_dit_inputs.images_list().tensors()) { + dit_inputs.images_list.emplace_back(util::proto_to_torch(pb_tensor)); + } } if (pb_dit_inputs.has_mask_images()) { diff --git a/xllm/core/scheduler/dit_scheduler.cpp b/xllm/core/scheduler/dit_scheduler.cpp index 7a0553e60..d0ec61e50 100644 --- a/xllm/core/scheduler/dit_scheduler.cpp +++ b/xllm/core/scheduler/dit_scheduler.cpp @@ -38,8 +38,11 @@ constexpr size_t kRequestQueueSize = 100; void DiTAsyncResponseProcessor::process_completed_request( std::shared_ptr request) { response_threadpool_.schedule([request = std::move(request)]() { - LOG(INFO) << "request_id: " << request->request_id(); + double end_2_end_latency_seconds = request->elapsed_seconds(); + HISTOGRAM_OBSERVE(end_2_end_latency_milliseconds, + static_cast(end_2_end_latency_seconds * 1000.0)); + request->log_statistic(end_2_end_latency_seconds); request->state().output_func()(request->generate_output()); }); } diff --git a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h index 399fc4933..610f11767 100644 --- a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h +++ b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h @@ 
-277,9 +277,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { auto seed = generation_params.seed >= 0 ? generation_params.seed : 42; auto prompts = input.prompts; - auto prompts_2 = input.prompts_2; auto negative_prompts = input.negative_prompts; - auto negative_prompts_2 = input.negative_prompts_2; auto latents = input.latents; if (latents.defined()) { latents = latents.to(options_.device(), dtype_); @@ -289,7 +287,6 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { if (prompt_embeds.defined()) { prompt_embeds = prompt_embeds.to(options_.device(), dtype_); } - auto pooled_prompt_embeds = input.pooled_prompt_embeds; torch::Tensor prompt_embeds_mask; auto negative_prompt_embeds = input.negative_prompt_embeds; @@ -297,56 +294,41 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { negative_prompt_embeds = negative_prompt_embeds.to(options_.device(), dtype_); } - auto negative_pooled_prompt_embeds = input.negative_pooled_prompt_embeds; torch::Tensor negative_prompt_embeds_mask; - std::vector image_list; - - torch::Tensor images; - if (::xllm::DiTConfig::get_instance().dit_debug_print()) { input.debug_print(); } - if (input.images.defined()) { - images = input.images.to(options_.device(), dtype_); - if (input.images.dim() == 3) { - image_list.emplace_back(images); - } else if (input.images.dim() == 4) { - if (input.images.size(0) > 1) { - LOG(ERROR) << "currently dit models doesn't support batch inference" - << "batch size: " << input.images.size(0); - } - image_list.emplace_back(images[0]); - } else { - LOG(ERROR) - << "image inputs are expected to be a 4 dim tensor, but got: " - << input.images.dim() << "s tensor"; - } - } else { - LOG(ERROR) << "QwenImageEditPlus pipeline expected to have " - << "image inputs"; + if (input.images_list.empty()) { + LOG(FATAL) << "QwenImageEditPlus pipeline expected to have " + << "image inputs in images_list"; } - torch::Tensor conditional_images; - if 
(input.condition_images.defined()) { - conditional_images = input.condition_images.to(options_.device(), dtype_); - if (input.condition_images.dim() == 3) { - image_list.emplace_back(conditional_images); - } else if (input.condition_images.dim() == 4) { - if (input.condition_images.size(0) > 1) { - LOG(ERROR) << "currently dit models doesn't support batch inference" - << "batch size: " << input.condition_images.size(0); - } - image_list.emplace_back(conditional_images[0]); - } else { + std::vector image_list; + image_list.reserve(input.images_list.size()); + + for (const auto& images : input.images_list) { + auto img = images.to(options_.device(), dtype_); + if (img.dim() != 4) { LOG(ERROR) << "image inputs are expected to be a 4 dim tensor, but got: " - << input.condition_images.dim() << "s tensor"; + << img.dim() << "d tensor"; + continue; } + if (img.size(0) > 1) { + LOG(ERROR) << "currently QwenImageEdit doesn't support batch inference" + << "batch size: " << img.size(0); + } + image_list.emplace_back(img[0]); + } + + if (image_list.empty()) { + LOG(FATAL) << "No valid images found in images_list. 
"; } - double height_size = images.size(2); - double width_size = images.size(3); + + double height_size = static_cast(image_list[0].size(1)); + double width_size = static_cast(image_list[0].size(2)); int64_t num_images_per_prompt = 1; double aspect_ratio = width_size / height_size; @@ -367,7 +349,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { std::vector vae_images; std::vector> condition_image_sizes; std::vector> vae_image_sizes; - if (images.defined() && !(images.size(1) == latent_channels_)) { + if (!image_list.empty() && image_list[0].size(0) != latent_channels_) { for (size_t i = 0; i < image_list.size(); i++) { aspect_ratio = static_cast(image_list[i].size(2)) / image_list[i].size(1); @@ -392,7 +374,8 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { } } - bool has_neg_prompt = negative_prompts.size() > 0; + bool has_neg_prompt = + negative_prompts.size() > 0 || negative_prompt_embeds.defined(); bool do_true_cfg = (true_cfg_scale > 1.0) && has_neg_prompt; // inplace update prompt_embeds and prompt_embeds_mask diff --git a/xllm/proto/image_generation.proto b/xllm/proto/image_generation.proto index 859b40fa7..18ef24e21 100644 --- a/xllm/proto/image_generation.proto +++ b/xllm/proto/image_generation.proto @@ -47,8 +47,8 @@ message Input { // Control Image optional string control_image = 13; - // Condition Image - optional string condition_image = 14; + // multiple image + repeated string images = 14; } // Generation parameters container diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 345c9a130..49a1d0a7d 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -211,7 +211,7 @@ message DiTForwardInput { // Tensor fields Tensor images = 6; - Tensor condition_images = 7; + TensorList images_list = 7; Tensor mask_images = 8; Tensor control_image = 9; Tensor masked_image_latents = 10;