Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions xllm/api_service/image_generation_service_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ void ImageGenerationServiceImpl::process_async_impl(
return;
}

// Check if the request is being rate-limited.
if (master_->get_rate_limiter()->is_limited()) {
call->finish_with_error(
StatusCode::RESOURCE_EXHAUSTED,
"The number of concurrent requests has reached the limit.");
return;
}
Comment thread
xiao-yu-chen marked this conversation as resolved.

// create DiTRequestParams for image generation request
DiTRequestParams request_params(
rpc_request, call->get_x_request_id(), call->get_x_request_time());
Expand All @@ -90,9 +98,11 @@ void ImageGenerationServiceImpl::process_async_impl(
call.get(),
[call,
model,
master = master_,
request_id = std::move(saved_request_id),
created_time = absl::ToUnixSeconds(absl::Now())](
const DiTRequestOutput& req_output) -> bool {
master->get_rate_limiter()->decrease_one_request();
if (req_output.status.has_value()) {
const auto& status = req_output.status.value();
if (!status.ok()) {
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ DECLARE_int64(dit_cache_end_blocks);

DECLARE_int64(dit_sp_communication_overlap);

DECLARE_int64(dit_generation_image_area_max);

DECLARE_bool(dit_debug_print);

DECLARE_bool(use_audio_in_video);
Expand Down
34 changes: 30 additions & 4 deletions xllm/core/framework/batch/dit_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
std::vector<torch::Tensor> negative_pooled_prompt_embeds;

std::vector<torch::Tensor> images;
std::vector<torch::Tensor> condition_images;
std::vector<torch::Tensor> mask_images;
std::vector<torch::Tensor> control_images;
std::vector<torch::Tensor> latents;
std::vector<torch::Tensor> masked_image_latents;

const auto batch_size = request_vec_.size();
prompt_embeds.reserve(batch_size);
pooled_prompt_embeds.reserve(batch_size);
Expand All @@ -76,6 +76,11 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
control_images.reserve(batch_size);
latents.reserve(batch_size);
masked_image_latents.reserve(batch_size);

std::vector<torch::Tensor> images_list;
size_t images_size = request_vec_[0]->state().input_params().images.size();
bool images_size_valid = images_size > 0;

for (const auto& request : request_vec_) {
const auto& generation_params = request->state().generation_params();
if (input.generation_params != generation_params) {
Expand Down Expand Up @@ -107,9 +112,12 @@ DiTForwardInput DiTBatch::prepare_forward_input() {

images.emplace_back(input_params.image);
mask_images.emplace_back(input_params.mask_image);
condition_images.emplace_back(input_params.condition_image);
control_images.emplace_back(input_params.control_image);

if (input_params.images.size() != images_size) {
images_size_valid = false;
}

// Voice cloning: prompt_audio is per-request (batch_size==1 in practice).
// Forward the first defined tensor; multi-batch voice cloning is not
// supported (different prompt lengths can't be stacked).
Expand Down Expand Up @@ -142,8 +150,26 @@ DiTForwardInput DiTBatch::prepare_forward_input() {
input.images = torch::stack(images);
}

if (check_tensors_valid(condition_images)) {
input.condition_images = torch::stack(condition_images);
if (images_size_valid) {
images_list.reserve(images_size);
std::vector<torch::Tensor> vec;
vec.reserve(request_vec_.size());

bool all_valid = true;
for (size_t idx = 0; idx < images_size; ++idx) {
vec.clear();
for (const auto& req : request_vec_) {
vec.emplace_back(req->state().input_params().images[idx]);
}
if (!check_tensors_valid(vec)) {
all_valid = false;
break;
}
images_list.emplace_back(torch::stack(vec));
}
if (all_valid) {
input.images_list = std::move(images_list);
}
}

if (check_tensors_valid(mask_images)) {
Expand Down
12 changes: 10 additions & 2 deletions xllm/core/framework/config/dit_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ DEFINE_bool(dit_debug_print,
false,
"whether print the debug info for dit models");

// Upper bound on the image area (width * height) accepted for one image
// generation request. The default of 0 leaves the area unrestricted.
// DiTConfig::from_flags() copies this flag into the DiTConfig property of the
// same name.
DEFINE_int64(dit_generation_image_area_max,
0,
"Maximum allowed image area (width * height) for image generation "
"requests. If set to 0, there is no limit.");

namespace xllm {

void DiTConfig::from_flags() {
Expand All @@ -76,7 +81,8 @@ void DiTConfig::from_flags() {
.dit_cache_start_blocks(FLAGS_dit_cache_start_blocks)
.dit_cache_end_blocks(FLAGS_dit_cache_end_blocks)
.dit_sp_communication_overlap(FLAGS_dit_sp_communication_overlap)
.dit_debug_print(FLAGS_dit_debug_print);
.dit_debug_print(FLAGS_dit_debug_print)
.dit_generation_image_area_max(FLAGS_dit_generation_image_area_max);
}

void DiTConfig::from_json(const JsonReader& json) {
Expand Down Expand Up @@ -104,7 +110,9 @@ void DiTConfig::from_json(const JsonReader& json) {
.dit_sp_communication_overlap(json.value_or<int64_t>(
"dit_sp_communication_overlap", dit_sp_communication_overlap()))
.dit_debug_print(
json.value_or<bool>("dit_debug_print", dit_debug_print()));
json.value_or<bool>("dit_debug_print", dit_debug_print()))
.dit_generation_image_area_max(json.value_or<int64_t>(
"dit_generation_image_area_max", dit_generation_image_area_max()));
}

DiTConfig& DiTConfig::get_instance() {
Expand Down
5 changes: 4 additions & 1 deletion xllm/core/framework/config/dit_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class DiTConfig final {
"dit_cache_start_blocks",
"dit_cache_end_blocks",
"dit_sp_communication_overlap",
"dit_debug_print"}};
// The separating comma is required: without it the compiler concatenates the
// two adjacent string literals into the single entry
// "dit_debug_printdit_generation_image_area_max", so neither option name is
// listed in this category.
"dit_debug_print",
"dit_generation_image_area_max"}};
return kOptionCategory;
}

Expand All @@ -77,6 +78,8 @@ class DiTConfig final {
PROPERTY(int64_t, dit_sp_communication_overlap) = 1;

PROPERTY(bool, dit_debug_print) = false;

PROPERTY(int64_t, dit_generation_image_area_max) = 0;
};

} // namespace xllm
40 changes: 35 additions & 5 deletions xllm/core/framework/request/dit_request_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "butil/base64.h"
#include "core/common/instance_name.h"
#include "core/common/macros.h"
#include "core/framework/config/dit_config.h"
#include "core/util/utils.h"
#include "core/util/uuid.h"
#include "mm_codec.h"
Expand Down Expand Up @@ -139,14 +140,19 @@ DiTRequestParams::DiTRequestParams(const proto::ImageGenerationRequest& request,
}
}

if (input.has_condition_image()) {
std::string raw_bytes;
if (!butil::Base64Decode(input.condition_image(), &raw_bytes)) {
input_params.images.reserve(input.images().size());
for (const auto& image : input.images()) {
std::string binary;
if (!butil::Base64Decode(image, &binary)) {
LOG(ERROR) << "Base64 image decode failed";
continue;
}
if (!decoder.decode(raw_bytes, input_params.condition_image)) {
torch::Tensor tensor;
if (!decoder.decode(binary, tensor)) {
LOG(ERROR) << "Image decode failed.";
continue;
}
input_params.images.emplace_back(std::move(tensor));
}

if (input.has_mask_image()) {
Expand Down Expand Up @@ -258,11 +264,35 @@ DiTRequestParams::DiTRequestParams(const proto::AudioGenerationRequest& request,

// Validates the request parameters before the request is scheduled.
//
// Checks, in order:
//   1. a text prompt or a defined prompt embedding tensor is present;
//   2. the requested width and height are both non-negative;
//   3. if DiTConfig::dit_generation_image_area_max() is positive,
//      width * height does not exceed it (a limit of 0 disables the check).
//
// On the first failing check the error is reported via CALLBACK_WITH_ERROR
// (presumably through `callback` -- the macro is defined elsewhere) with
// StatusCode::INVALID_ARGUMENT, and false is returned. Returns true when all
// checks pass.
bool DiTRequestParams::verify_params(
    std::function<bool(DiTRequestOutput)> callback) const {
  // A request is acceptable with either a textual prompt or an embedding.
  if (input_params.prompt.empty() && !input_params.prompt_embed.defined()) {
    CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT, "prompt is empty");
    return false;
  }

  if (generation_params.width < 0 || generation_params.height < 0) {
    CALLBACK_WITH_ERROR(
        StatusCode::INVALID_ARGUMENT,
        "Invalid image dimensions: width and height must be non-negative.");
    return false;
  }

  // Read the limit once so the comparison and the error message are guaranteed
  // to use the same value.
  const int64_t area_max =
      ::xllm::DiTConfig::get_instance().dit_generation_image_area_max();
  if (area_max > 0) {
    // Both operands are converted to int64_t before the multiplication.
    const int64_t area = static_cast<int64_t>(generation_params.width) *
                         static_cast<int64_t>(generation_params.height);
    if (area > area_max) {
      CALLBACK_WITH_ERROR(StatusCode::INVALID_ARGUMENT,
                          "Requested image area (" + std::to_string(area) +
                              ") exceeds the maximum allowed area (" +
                              std::to_string(area_max) + ").");
      return false;
    }
  }

  return true;
}

Expand Down
2 changes: 1 addition & 1 deletion xllm/core/framework/request/dit_request_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ struct DiTInputParams {

torch::Tensor image;

torch::Tensor condition_image;
std::vector<torch::Tensor> images;

torch::Tensor control_image;

Expand Down
5 changes: 0 additions & 5 deletions xllm/core/framework/request/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ class Request : public RequestBase {

bool cancelled() const { return cancelled_.load(std::memory_order_relaxed); }

// Returns the time elapsed since created_time_, measured against absl::Now()
// and expressed in seconds as a double.
double elapsed_seconds() const {
return absl::ToDoubleSeconds(absl::Now() - created_time_);
}

RequestOutput generate_output(const Tokenizer& tokenizer,
ThreadPool* thread_pool = nullptr);

Expand Down
5 changes: 5 additions & 0 deletions xllm/core/framework/request/request_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class RequestBase {

absl::Time created_time() const { return created_time_; }

// Get the elapsed time since the request was created.
double elapsed_seconds() const {
return absl::ToDoubleSeconds(absl::Now() - created_time_);
}

const std::string& request_id() const { return request_id_; }

const std::string& service_request_id() const { return service_request_id_; }
Expand Down
20 changes: 12 additions & 8 deletions xllm/core/runtime/dit_forward_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,16 @@ struct DiTForwardInput {
os << "undefined" << std::endl;
}

os << "condition_images: ";
if (condition_images.defined()) {
os << condition_images.sizes() << std::endl;
} else {
os << "undefined" << std::endl;
os << "images_list: [";
for (size_t i = 0; i < images_list.size(); ++i) {
if (images_list[i].defined()) {
os << images_list[i].sizes();
} else {
os << "undefined";
}
if (i < images_list.size() - 1) os << ", ";
}
os << "]" << std::endl;

os << "mask_images: ";
if (mask_images.defined()) {
Expand Down Expand Up @@ -200,8 +204,8 @@ struct DiTForwardInput {
input.mask_images = mask_images.to(device, dtype);
}

if (condition_images.defined()) {
input.condition_images = condition_images.to(device, dtype);
for (auto& img : input.images_list) {
img = img.to(device, dtype);
}

if (control_image.defined()) {
Expand Down Expand Up @@ -231,7 +235,7 @@ struct DiTForwardInput {

torch::Tensor images;

torch::Tensor condition_images;
std::vector<torch::Tensor> images_list;

torch::Tensor mask_images;

Expand Down
Loading
Loading