diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 0a2e23e47..ec159ba2b 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -97,6 +97,8 @@ DECLARE_int64(sp_size); DECLARE_int64(cfg_size); +DECLARE_int64(vae_size); + DECLARE_bool(enable_prefill_sp); DECLARE_bool(enable_multi_stream_parallel); @@ -333,6 +335,8 @@ DECLARE_int64(dit_sp_communication_overlap); DECLARE_bool(dit_debug_print); +DECLARE_bool(enable_dit_vae_tiling); + DECLARE_bool(use_audio_in_video); // --- kernel config --- diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h index 588de4556..1c1bc3fda 100644 --- a/xllm/core/common/options.h +++ b/xllm/core/common/options.h @@ -134,6 +134,8 @@ class Options { PROPERTY(int32_t, cfg_size) = 1; + PROPERTY(int32_t, vae_size) = 1; + PROPERTY(std::optional, instance_name); PROPERTY(bool, enable_disagg_pd) = false; diff --git a/xllm/core/distributed_runtime/dist_manager.cpp b/xllm/core/distributed_runtime/dist_manager.cpp index f7e242ff9..121cae4ce 100644 --- a/xllm/core/distributed_runtime/dist_manager.cpp +++ b/xllm/core/distributed_runtime/dist_manager.cpp @@ -188,12 +188,13 @@ void DistManager::setup_multi_node_workers( const int32_t tp_size = options.tp_size(); const int32_t sp_size = options.sp_size(); const int32_t cfg_size = options.cfg_size(); + const int32_t vae_size = options.vae_size(); LOG(INFO) << "Multi-node serving world_size = " << world_size << ", each_node_ranks = " << each_node_ranks << ", current node rank = " << options.node_rank() << ", nnodes = " << options.nnodes() << ", dp_size = " << dp_size << ", tp_size = " << tp_size << ", sp_size = " << sp_size - << ", cfg_size = " << cfg_size; + << ", cfg_size = " << cfg_size << ", vae_size = " << vae_size; } else { LOG(INFO) << "Multi-node serving world_size = " << world_size << ", each_node_ranks = " << each_node_ranks diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index 1ece38b07..f76e6dd74 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -457,7 +457,8 @@ Master::Master(const Options& options, EngineType type) .ep_size(options_.ep_size()) .tp_size(options_.tp_size()) .sp_size(options_.sp_size()) - .cfg_size(options_.cfg_size()); + .cfg_size(options_.cfg_size()) + .vae_size(options_.vae_size()); auto dit_engine = std::make_unique(eng_options); engine_ = std::move(dit_engine); diff --git a/xllm/core/distributed_runtime/worker_server.cpp b/xllm/core/distributed_runtime/worker_server.cpp index 73091217e..07247158f 100644 --- a/xllm/core/distributed_runtime/worker_server.cpp +++ b/xllm/core/distributed_runtime/worker_server.cpp @@ -160,7 +160,8 @@ void WorkerServer::create_server( options.dp_size(), options.tp_size(), options.sp_size(), - options.cfg_size()); + options.cfg_size(), + options.vae_size()); comm = std::move(dit_comm); } else { auto common_comm = std::make_unique( diff --git a/xllm/core/framework/config/dit_config.cpp b/xllm/core/framework/config/dit_config.cpp index c071b28aa..d6df6b6b2 100644 --- a/xllm/core/framework/config/dit_config.cpp +++ b/xllm/core/framework/config/dit_config.cpp @@ -61,6 +61,13 @@ DEFINE_bool(dit_debug_print, false, "whether print the debug info for dit models"); +// --- dit vae tiling --- + +DEFINE_bool( + enable_dit_vae_tiling, + false, + "whether enable vae tiling, currently only support qwen-image-edit-plus"); + namespace xllm { void DiTConfig::from_flags() { @@ -76,7 +83,8 @@ void DiTConfig::from_flags() { .dit_cache_start_blocks(FLAGS_dit_cache_start_blocks) .dit_cache_end_blocks(FLAGS_dit_cache_end_blocks) .dit_sp_communication_overlap(FLAGS_dit_sp_communication_overlap) - .dit_debug_print(FLAGS_dit_debug_print); + .dit_debug_print(FLAGS_dit_debug_print) + .enable_dit_vae_tiling(FLAGS_enable_dit_vae_tiling); } void DiTConfig::from_json(const JsonReader& json) { @@ -104,7 +112,9 @@ void DiTConfig::from_json(const JsonReader& json) { .dit_sp_communication_overlap(json.value_or( "dit_sp_communication_overlap", dit_sp_communication_overlap())) .dit_debug_print( - json.value_or("dit_debug_print", dit_debug_print())); + json.value_or("dit_debug_print", dit_debug_print())) + .enable_dit_vae_tiling(json.value_or("enable_dit_vae_tiling", + enable_dit_vae_tiling())); } DiTConfig& DiTConfig::get_instance() { diff --git a/xllm/core/framework/config/dit_config.h b/xllm/core/framework/config/dit_config.h index af30f5d81..7557bd8c3 100644 --- a/xllm/core/framework/config/dit_config.h +++ b/xllm/core/framework/config/dit_config.h @@ -50,7 +50,8 @@ class DiTConfig final { "dit_cache_start_blocks", "dit_cache_end_blocks", "dit_sp_communication_overlap", - "dit_debug_print"}}; + "dit_debug_print", + "enable_dit_vae_tiling"}}; return kOptionCategory; } @@ -77,6 +78,8 @@ class DiTConfig final { PROPERTY(int64_t, dit_sp_communication_overlap) = 1; PROPERTY(bool, dit_debug_print) = false; + + PROPERTY(bool, enable_dit_vae_tiling) = false; }; } // namespace xllm diff --git a/xllm/core/framework/config/parallel_config.cpp b/xllm/core/framework/config/parallel_config.cpp index a214c6450..fd27b3435 100644 --- a/xllm/core/framework/config/parallel_config.cpp +++ b/xllm/core/framework/config/parallel_config.cpp @@ -33,6 +33,8 @@ DEFINE_int64(cfg_size, "Classifier-free guidiance parallelism size, only used for DiT " "model."); +DEFINE_int64(vae_size, 1, "Vae patch parallelism size"); + DEFINE_string( communication_backend, "hccl", @@ -67,6 +69,7 @@ void ParallelConfig::from_flags() { .tp_size(FLAGS_tp_size) .sp_size(FLAGS_sp_size) .cfg_size(FLAGS_cfg_size) + .vae_size(FLAGS_vae_size) .communication_backend(FLAGS_communication_backend) .enable_prefill_sp(FLAGS_enable_prefill_sp) .enable_multi_stream_parallel(FLAGS_enable_multi_stream_parallel) @@ -81,6 +84,7 @@ void ParallelConfig::from_json(const JsonReader& json) { .tp_size(json.value_or("tp_size", tp_size())) .sp_size(json.value_or("sp_size", sp_size())) .cfg_size(json.value_or("cfg_size", cfg_size())) + .vae_size(json.value_or("vae_size", vae_size())) .communication_backend(json.value_or( "communication_backend", communication_backend())) .enable_prefill_sp( diff --git a/xllm/core/framework/config/parallel_config.h b/xllm/core/framework/config/parallel_config.h index 6c1ea28c2..cd0bd56c4 100644 --- a/xllm/core/framework/config/parallel_config.h +++ b/xllm/core/framework/config/parallel_config.h @@ -45,6 +45,7 @@ class ParallelConfig final { "tp_size", "sp_size", "cfg_size", + "vae_size", "communication_backend", "enable_prefill_sp", "enable_multi_stream_parallel", @@ -65,6 +66,8 @@ class ParallelConfig final { PROPERTY(int64_t, cfg_size) = 1; + PROPERTY(int64_t, vae_size) = 1; + PROPERTY(std::string, communication_backend) = "hccl"; PROPERTY(bool, enable_prefill_sp) = false; diff --git a/xllm/core/framework/parallel_state/dit_collective_communicator.cpp b/xllm/core/framework/parallel_state/dit_collective_communicator.cpp index 4f3304f57..3a40a4900 100644 --- a/xllm/core/framework/parallel_state/dit_collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/dit_collective_communicator.cpp @@ -40,7 +40,8 @@ DiTCollectiveCommunicator::DiTCollectiveCommunicator(int32_t global_rank, int32_t dit_dp_size, int32_t dit_tp_size, int32_t dit_sp_size, - int32_t dit_cfg_size) + int32_t dit_cfg_size, + int32_t dit_vae_size) : CollectiveCommunicatorBase(global_rank, world_size) { parallel_args_ = std::make_unique(global_rank, world_size, @@ -48,12 +49,14 @@ DiTCollectiveCommunicator::DiTCollectiveCommunicator(int32_t global_rank, dit_tp_size, dit_sp_size, dit_cfg_size, + dit_vae_size, /*process_group=*/nullptr); DiTMapping::Options dit_mapping_options; dit_mapping_options.dit_tp_size(dit_tp_size) .dit_sp_size(dit_sp_size) .dit_cfg_size(dit_cfg_size) - .dit_dp_size(dit_dp_size); + .dit_dp_size(dit_dp_size) + .dit_vae_size(dit_vae_size); dit_mapping_ = std::make_unique( world_size, global_rank, dit_mapping_options); } @@ -63,114 +66,75 @@ void DiTCollectiveCommunicator::create_process_groups( const torch::Device& device) { Device device_(device); device_.set_device(); - std::string host; - int32_t port; - net::parse_host_port_from_addr(master_addr, host, port); - - int32_t global_rank = parallel_args_->rank(); - int32_t world_size = parallel_args_->world_size(); - int32_t dp_size = parallel_args_->dp_size(); - int32_t tp_size = parallel_args_->tp_size(); - int32_t sp_size = parallel_args_->sp_size(); - int32_t cfg_size = parallel_args_->cfg_size(); - - process_group_ = create_process_group(global_rank, - world_size, - world_size, - ++port, + net::parse_host_port_from_addr(master_addr, host_, port_); + + process_group_ = create_process_group(global_rank_, + world_size_, + world_size_, + ++port_, false, - host, + host_, "world_group", device); parallel_args_->process_group_ = process_group_.get(); - if (tp_size > 1 && dit_mapping_) { - auto tp_parallel_info = dit_mapping_->get_parallel_info("tp"); - auto group_id = tp_parallel_info.current_group_id(); - auto num_group = tp_parallel_info.num_group(); - auto local_rank = tp_parallel_info.rank(); - auto& rank_per_group = tp_parallel_info.rank_per_group()[group_id]; - int port_offset = group_id + 1; -#if defined(USE_NPU) || defined(USE_MLU) - dit_tp_group_ = create_process_group(global_rank, - local_rank, - rank_per_group, - world_size, - tp_size, - port + port_offset, - host, - "tp_group", - device); -#endif - parallel_args_->dit_tp_group_ = dit_tp_group_.get(); - port += num_group; - } - - if (sp_size > 1 && dit_mapping_) { - auto sp_parallel_info = dit_mapping_->get_parallel_info("sp"); - auto group_id = sp_parallel_info.current_group_id(); - auto num_group = sp_parallel_info.num_group(); - auto local_rank = sp_parallel_info.rank(); - auto& rank_per_group = sp_parallel_info.rank_per_group()[group_id]; - int port_offset = group_id + 1; -#if defined(USE_NPU) || defined(USE_MLU) - dit_sp_group_ = create_process_group(global_rank, - local_rank, - rank_per_group, - world_size, - sp_size, - port + port_offset, - host, - "sp_group", - device); -#endif - parallel_args_->dit_sp_group_ = dit_sp_group_.get(); - port += num_group; - } - - if (cfg_size > 1 && dit_mapping_) { - auto cfg_parallel_info = dit_mapping_->get_parallel_info("cfg"); - auto group_id = cfg_parallel_info.current_group_id(); - auto num_group = cfg_parallel_info.num_group(); - auto local_rank = cfg_parallel_info.rank(); - auto& rank_per_group = cfg_parallel_info.rank_per_group()[group_id]; - int port_offset = group_id + 1; -#if defined(USE_NPU) || defined(USE_MLU) - dit_cfg_group_ = create_process_group(global_rank, - local_rank, - rank_per_group, - world_size, - cfg_size, - port + port_offset, - host, - "cfg_group", - device); -#endif - parallel_args_->dit_cfg_group_ = dit_cfg_group_.get(); - port += num_group; + std::unordered_map groups = { + {"tp", ¶llel_args_->dit_tp_group_}, + {"sp", ¶llel_args_->dit_sp_group_}, + {"cfg", ¶llel_args_->dit_cfg_group_}, + {"dp", ¶llel_args_->dit_dp_group_}, + {"vae", ¶llel_args_->dit_vae_group_}}; + + // we use class members to extend the lifetime of these ProceesGroups, so that + // they won't be destructed. + group_map_ = {{"tp", &dit_tp_group_}, + {"sp", &dit_sp_group_}, + {"dp", &dit_dp_group_}, + {"cfg", &dit_cfg_group_}, + {"vae", &dit_vae_group_}}; + + for (auto& [group_type, process_group] : groups) { + create_process_group_by_type(group_type, *process_group, device); } +} - if (dp_size > 1 && dit_mapping_) { - auto dp_parallel_info = dit_mapping_->get_parallel_info("dp"); - auto group_id = dp_parallel_info.current_group_id(); - auto num_group = dp_parallel_info.num_group(); - auto local_rank = dp_parallel_info.rank(); - auto& rank_per_group = dp_parallel_info.rank_per_group()[group_id]; +void DiTCollectiveCommunicator::create_process_group_by_type( + const std::string& group_type, + ProcessGroup*& process_group, + const torch::Device& device) { + int32_t group_size = parallel_args_->get_group_size_by_type(group_type); + if (group_size > 1 && dit_mapping_) { + auto parallel_info = dit_mapping_->get_parallel_info(group_type); + auto group_id = parallel_info.current_group_id(); + auto num_group = parallel_info.num_group(); + auto local_rank = parallel_info.rank(); + auto& rank_per_group = parallel_info.rank_per_group()[group_id]; int port_offset = group_id + 1; #if defined(USE_NPU) || defined(USE_MLU) - dit_dp_group_ = create_process_group(global_rank, - local_rank, - rank_per_group, - world_size, - dp_size, - port + port_offset, - host, - "dp_group", - device); + *group_map_[group_type] = + std::move(create_process_group(global_rank_, + local_rank, + rank_per_group, + world_size_, + group_size, + port_ + port_offset, + host_, + group_type + "_group", + device)); + process_group = (*group_map_[group_type]).get(); +#else + LOG(INFO) << "create_process_group function is used by DiT models, since " + "the DiT communication group " + << "info have already been calculated by rank_generator, we only " + "need to pass the " + << "info to create the process groups. For any device that want " + "to reuse the " + << "function and dit process groups, please implement the " + "corresponding " + << "ProcessGroupImpl construct function. "; #endif - parallel_args_->dit_dp_group_ = dit_dp_group_.get(); - port += num_group; + port_ += num_group; } } diff --git a/xllm/core/framework/parallel_state/dit_collective_communicator.h b/xllm/core/framework/parallel_state/dit_collective_communicator.h index 0acb68431..92593eba2 100644 --- a/xllm/core/framework/parallel_state/dit_collective_communicator.h +++ b/xllm/core/framework/parallel_state/dit_collective_communicator.h @@ -27,7 +27,8 @@ class DiTCollectiveCommunicator : public CollectiveCommunicatorBase { int32_t dit_dp_size, int32_t dit_tp_size, int32_t dit_sp_size, - int32_t dit_cfg_size); + int32_t dit_cfg_size, + int32_t dit_vae_size); ~DiTCollectiveCommunicator() = default; @@ -37,7 +38,13 @@ class DiTCollectiveCommunicator : public CollectiveCommunicatorBase { // init communicator and return parallel args. const ParallelArgs* parallel_args() override; + void create_process_group_by_type(const std::string& group_type, + ProcessGroup*& process_group, + const torch::Device& device); + private: + int32_t port_ = 0; + std::string host_ = ""; std::unique_ptr dit_mapping_{nullptr}; std::unique_ptr parallel_args_; std::unique_ptr process_group_; @@ -45,6 +52,8 @@ class DiTCollectiveCommunicator : public CollectiveCommunicatorBase { std::unique_ptr dit_sp_group_; std::unique_ptr dit_dp_group_; std::unique_ptr dit_cfg_group_; + std::unique_ptr dit_vae_group_; + std::unordered_map*> group_map_; }; } // namespace xllm diff --git a/xllm/core/framework/parallel_state/dit_mapping.cpp b/xllm/core/framework/parallel_state/dit_mapping.cpp index 96608d2c1..cdf2e8378 100644 --- a/xllm/core/framework/parallel_state/dit_mapping.cpp +++ b/xllm/core/framework/parallel_state/dit_mapping.cpp @@ -19,6 +19,8 @@ limitations under the License. namespace xllm { +bool RankGenerator::initialized_ = false; + DiTMapping::DiTMapping(const int32_t world_size, const int32_t rank, const Options& options) @@ -27,18 +29,26 @@ DiTMapping::DiTMapping(const int32_t world_size, sp_.backend("hccl"); cfg_.backend("hccl"); dp_.backend("hccl"); + vae_.backend("hccl"); parse_parallel_info(); validate(); - rank_generator_ = - std::make_unique(tp_.group_size(), - sp_.group_size(), - cfg_.group_size(), - dp_.group_size(), - /*group_order=*/"tp-sp-cfg-dp"); - set_group_by_type(tp_, "tp"); - set_group_by_type(sp_, "sp"); - set_group_by_type(cfg_, "cfg"); - set_group_by_type(dp_, "dp"); + RankGenerator::init(world_size); + std::vector group_ranks = { + tp_.group_size(), sp_.group_size(), cfg_.group_size(), dp_.group_size()}; + std::vector group_order = {"tp", "sp", "cfg", "dp"}; + auto ranks_mapping = + RankGenerator::getInstance().get_ranks_mapping(group_ranks, group_order); + + set_group_by_type(tp_, "tp", ranks_mapping->at("tp")); + set_group_by_type(sp_, "sp", ranks_mapping->at("sp")); + set_group_by_type(cfg_, "cfg", ranks_mapping->at("cfg")); + set_group_by_type(dp_, "dp", ranks_mapping->at("dp")); + + std::vector vae_group_ranks = {vae_.group_size()}; + std::vector vae_group_order = {"vae"}; + auto ranks_mapping_vae = RankGenerator::getInstance().get_ranks_mapping( + vae_group_ranks, vae_group_order); + set_group_by_type(vae_, "vae", ranks_mapping_vae->at("vae")); } void DiTMapping::parse_parallel_info() { @@ -54,6 +64,9 @@ void DiTMapping::parse_parallel_info() { if (options_.dit_dp_size() != -1) { dp_.group_size(options_.dit_dp_size()); } + if (options_.dit_vae_size() != -1) { + vae_.group_size(options_.dit_vae_size()); + } } void DiTMapping::validate() { @@ -78,15 +91,30 @@ void DiTMapping::validate() { ". " "Please check `cfg`, `tp`, `sp`, `dp` and `world_size`."; - CHECK(cfg_.group_size() <= 2) << "cfg_size must less than 2 " - "cfg_size is " + - std::to_string(cfg_.group_size()) + - ". Please check `cfg` ."; + CHECK(cfg_.group_size() <= 2 && cfg_.group_size() >= 1) + << "cfg_size must less than 2 " + "cfg_size is " + + std::to_string(cfg_.group_size()) + ". Please check `cfg` ."; + + CHECK(vae_.group_size() <= world_size_) + << "vae_size could not greater than world_size. " + "vae_size is " + + std::to_string(vae_.group_size()) + ", world_size is " + + std::to_string(world_size_) + + ". Please check `vae` and 'world_size'."; + + CHECK(world_size_ % vae_.group_size() == 0) + << "world_size could not be divided by world_size. " + "vae_size is " + + std::to_string(vae_.group_size()) + ", world_size is " + + std::to_string(world_size_) + + ". Please check `vae` and 'world_size'."; } -void DiTMapping::set_group_by_type(ParallelInfo& parallel_info, - const std::string& group_type) { - auto rank_per_group = rank_generator_->get_ranks(group_type); +void DiTMapping::set_group_by_type( + ParallelInfo& parallel_info, + const std::string& group_type, + std::vector> rank_per_group) { parallel_info.rank_per_group(rank_per_group); auto group_size = rank_per_group[0].size(); parallel_info.num_group(world_size_ / group_size); @@ -122,8 +150,10 @@ const ParallelInfo& DiTMapping::get_parallel_info( return cfg_; } else if (group_type == "dp") { return dp_; + } else if (group_type == "vae") { + return vae_; } else { - LOG(ERROR) << "get unexpected group_type: " << group_type; + LOG(FATAL) << "get unexpected group_type: " << group_type; } } @@ -139,6 +169,7 @@ nlohmann::json DiTMapping::to_json() { data["tp"] = tp_.to_json(); data["cfg"] = cfg_.to_json(); data["dp"] = dp_.to_json(); + data["vae"] = vae_.to_json(); return data; } diff --git a/xllm/core/framework/parallel_state/dit_mapping.h b/xllm/core/framework/parallel_state/dit_mapping.h index 02b0e1af8..dfb72a3d9 100644 --- a/xllm/core/framework/parallel_state/dit_mapping.h +++ b/xllm/core/framework/parallel_state/dit_mapping.h @@ -34,6 +34,8 @@ class DiTMapping final { PROPERTY(int32_t, dit_sp_size) = -1; // dp size PROPERTY(int32_t, dit_dp_size) = -1; + // vae size + PROPERTY(int32_t, dit_vae_size) = -1; }; DiTMapping(const int32_t world_size, @@ -47,7 +49,8 @@ class DiTMapping final { void validate(); void set_group_by_type(ParallelInfo& parallel_info, - const std::string& group_type); + const std::string& group_type, + std::vector> rank_per_group); std::tuple get_current_group_id( const std::vector>& rank_per_group, @@ -67,6 +70,6 @@ class DiTMapping final { ParallelInfo tp_ = ParallelInfo(); ParallelInfo cfg_ = ParallelInfo(); ParallelInfo dp_ = ParallelInfo(); - std::unique_ptr rank_generator_{nullptr}; + ParallelInfo vae_ = ParallelInfo(); }; } // namespace xllm diff --git a/xllm/core/framework/parallel_state/parallel_args.h b/xllm/core/framework/parallel_state/parallel_args.h index 4505a7191..b8fe9da5a 100644 --- a/xllm/core/framework/parallel_state/parallel_args.h +++ b/xllm/core/framework/parallel_state/parallel_args.h @@ -94,6 +94,7 @@ struct ParallelArgs { int32_t tp_size, int32_t sp_size, int32_t cfg_size, + int32_t vae_size, ProcessGroup* process_group) : rank_(rank), world_size_(world_size), @@ -101,8 +102,31 @@ struct ParallelArgs { tp_size_(tp_size), sp_size_(sp_size), cfg_size_(cfg_size), + vae_size_(vae_size), process_group_(process_group) {} + int32_t get_group_size_by_type(const std::string& group_type) const { + if (group_type == "tp") { + return tp_size(); + } else if (group_type == "sp") { + return sp_size(); + } else if (group_type == "cfg") { + return cfg_size(); + } else if (group_type == "dp") { + return dp_size(); + } else if (group_type == "ep") { + return ep_size(); + } else if (group_type == "vae") { + return vae_size(); + } else if (group_type == "cp") { + return cp_size(); + } else { + LOG(FATAL) << "get unexpected group_type: " << group_type; + return 1; + } + return 1; + } + // rank of current process PROPERTY(int32_t, rank) = 0; @@ -127,6 +151,9 @@ struct ParallelArgs { // cfg size PROPERTY(int32_t, cfg_size) = 1; + // cfg size + PROPERTY(int32_t, vae_size) = 1; + // atb hccl mapping json data PROPERTY(nlohmann::json, mapping_data); @@ -160,6 +187,7 @@ struct ParallelArgs { ProcessGroup* dit_sp_group_ = nullptr; ProcessGroup* dit_cfg_group_ = nullptr; ProcessGroup* dit_dp_group_ = nullptr; + ProcessGroup* dit_vae_group_ = nullptr; }; } // namespace xllm diff --git a/xllm/core/framework/parallel_state/rank_generator.h b/xllm/core/framework/parallel_state/rank_generator.h index 8370562a8..d05f50940 100644 --- a/xllm/core/framework/parallel_state/rank_generator.h +++ b/xllm/core/framework/parallel_state/rank_generator.h @@ -1,3 +1,19 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once #include #include @@ -7,73 +23,107 @@ #include "core/common/global_flags.h" #include "core/framework/config/dit_config.h" +namespace xllm { + +/* +group_ranks: the rank sizes of the sub groups +group_order: the priority of the sub groups, the group with a + higher priority will be assigned closer rank ids. +world_size: the global world_size +*/ class RankGenerator { public: - RankGenerator(int32_t tp, - int32_t sp, - int32_t cfg, - int32_t dp, - const std::string& group_order = "tp-sp-cfg-dp", - int32_t rank_offset = 0) - : tp_(tp), sp_(sp), cfg_(cfg), dp_(dp), rank_offset_(rank_offset) { - world_size_ = tp * sp * cfg * dp; - - group_size_map_["tp"] = tp; - group_size_map_["sp"] = sp; - group_size_map_["cfg"] = cfg; - group_size_map_["dp"] = dp; - - auto full_order = group_order; - for (const auto& group_size_pair : group_size_map_) { - const std::string& group_name = group_size_pair.first; - int32_t group_size = group_size_pair.second; - - if (full_order.find(group_name) == std::string::npos) { - if (group_size != 1) { - LOG(FATAL) << "The size of (" << group_name << ") is (" << group_size - << "), but you haven't specified it in order (" - << full_order << ")."; - } else { - full_order = full_order + "-" + group_name; - } - } + static void init(int32_t world_size, int32_t rank_offset = 0) { + if (initialized_) { + LOG(FATAL) << "repeated initialize RankGenerator"; + } else { + RankGenerator& instance = getInstance(); + instance = RankGenerator(world_size, rank_offset); + initialized_ = true; } + } - group_order_ = full_order; + static RankGenerator& getInstance() { + static RankGenerator instance; + return instance; + } - auto split = [](const std::string& s, - char delimiter) -> std::vector { - std::vector tokens; - std::string token; - std::istringstream tokenStream(s); - while (std::getline(tokenStream, token, delimiter)) { - tokens.push_back(token); - } - return tokens; - }; + std::shared_ptr< + std::unordered_map>>> + get_ranks_mapping(std::vector& group_ranks, + std::vector& group_order) { + CHECK(!group_ranks.empty() && group_ranks.size() != 0) + << "The RankGenerator expected to initialize with group_ranks that " + "contains the ranks of sub groups" + << ", but got an empty group_ranks"; + + CHECK(!group_order.empty() && group_order.size() != 0) + << "The RankGenerator expected to initialize with group_order that " + "indicates the priority of sub groups" + << ", but got empty string"; + + int32_t product_size = 1; + for (const auto& group_rank : group_ranks) { + product_size *= group_rank; + } - ordered_group_name_ = split(group_order_, '-'); - for (const std::string& token : ordered_group_name_) { - auto it = group_size_map_.find(token); - if (it != group_size_map_.end()) { - ordered_group_size_.push_back(it->second); + bool is_single_group = (group_ranks.size() == 1); + if (is_single_group && group_ranks[0] != world_size_) { + if (world_size_ % group_ranks[0] != 0) { + LOG(FATAL) << "The world_size could not be divided by vae_size, " + << "got world_size: " << world_size_ + << ", vae_size: " << group_ranks[0] << "."; } + LOG(WARNING) << "The sub group size does not equal world_size" + << ", we will assign the " << group_order[0] + << ", with sub group size: " << group_ranks[0]; + group_ranks.emplace_back(world_size_ / group_ranks[0]); + group_order.emplace_back("place_holder"); + } else if (world_size_ != product_size) { + LOG(FATAL) << "The world_size does not equals the product of sub " + "group sizes, " + << "got world_size: " << world_size_ + << ", sub groups: " << group_order[0] + << "sub groups sizes: " << group_ranks[0]; } - LOG(INFO) << "RankGenerator initialized with tp=" << tp << ", sp=" << sp - << ", cfg=" << cfg << ", dp=" << dp << ", order=" << group_order_ - << ", world_size=" << world_size_; + CHECK(group_order.size() == group_ranks.size()) + << "The size of group_ranks does not equals the size of group_order."; - if (::xllm::DiTConfig::get_instance().dit_debug_print()) { - debug_print(); + std::stringstream ss; + for (size_t i = 0; i < group_ranks.size(); i++) { + ss << group_order[i] << "=" << group_ranks[i] << ", "; } + + LOG(INFO) << "RankGenerator initialized with " << ss.str() + << "world_size=" << world_size_; + + auto group_mapping = std::make_shared< + std::unordered_map>>>(); + for (auto& group_name : group_order) { + auto sub_group_ranks = get_ranks(group_name, group_ranks, group_order); + if (::xllm::DiTConfig::get_instance().dit_debug_print()) { + print_ranks(group_name, sub_group_ranks); + } + group_mapping->insert({group_name, sub_group_ranks}); + } + return group_mapping; } - std::vector> get_ranks(const std::string& group_query) { - std::vector mask = get_mask(group_query); + int32_t get_world_size() const { return world_size_; } + static bool initialized_; + + private: + RankGenerator(int32_t world_size = 1, int32_t rank_offset = 0) + : world_size_(world_size), rank_offset_(rank_offset) {} + + std::vector> get_ranks( + const std::string& group_query, + const std::vector& group_ranks, + const std::vector& group_order) { + std::vector mask = get_mask(group_query, group_order); std::vector> ranks = - generate_masked_orthogonal_rank_groups( - world_size_, ordered_group_size_, mask); + generate_masked_orthogonal_rank_groups(world_size_, group_ranks, mask); if (rank_offset_ > 0) { for (auto& rank_group : ranks) { for (size_t i = 0; i < rank_group.size(); i++) { @@ -85,23 +135,8 @@ class RankGenerator { return ranks; } - int32_t get_world_size() const { return world_size_; } - const std::string& get_order() const { return group_order_; } - int32_t get_tp() const { return tp_; } - int32_t get_sp() const { return sp_; } - int32_t get_cfg() const { return cfg_; } - int32_t get_dp() const { return dp_; } - - void debug_print() { - print_ranks("cfg"); - print_ranks("tp"); - print_ranks("sp"); - print_ranks("dp"); - } - - void print_ranks(const std::string& group_query) { - auto ranks = get_ranks(group_query); - + void print_ranks(const std::string& group_query, + const std::vector>& ranks) { std::stringstream ss; ss << "Ranks for query '" << group_query << "':" << std::endl; for (size_t i = 0; i < ranks.size(); i++) { @@ -115,7 +150,6 @@ class RankGenerator { LOG(INFO) << ss.str(); } - private: std::vector prefix_product(const std::vector& group_size, int32_t init = 1) { std::vector prefix_product_sizes; @@ -206,7 +240,6 @@ class RankGenerator { // group size equals to the product of queryed group type sizes; int32_t group_size = queried_group_prefix.back(); int32_t num_of_group = world_size / group_size; - std::vector> ranks; for (int32_t group_index = 0; group_index < num_of_group; group_index++) { std::vector decomposed_group_idx = @@ -227,15 +260,25 @@ class RankGenerator { return ranks; } - std::vector get_mask(const std::string& group_query) { + std::vector get_mask(const std::string& group_query, + const std::vector& group_order) { + auto split = [](const std::string& s, + char delimiter) -> std::vector { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; + }; std::vector query_group_name = split(group_query, '-'); - std::vector mask(ordered_group_name_.size(), false); + std::vector mask(group_order.size(), false); for (const std::string& group_name : query_group_name) { - auto it = std::find( - ordered_group_name_.begin(), ordered_group_name_.end(), group_name); - if (it != ordered_group_name_.end()) { - size_t index = std::distance(ordered_group_name_.begin(), it); + auto it = std::find(group_order.begin(), group_order.end(), group_name); + if (it != group_order.end()) { + size_t index = std::distance(group_order.begin(), it); mask[index] = true; } } @@ -253,15 +296,8 @@ class RankGenerator { return tokens; } - private: - int32_t tp_; - int32_t sp_; - int32_t cfg_; - int32_t dp_; int32_t rank_offset_; int32_t world_size_; - std::string group_order_; - std::vector ordered_group_size_; - std::vector ordered_group_name_; - std::unordered_map group_size_map_; }; + +} // namespace xllm diff --git a/xllm/core/runtime/embed_vlm_worker_impl.cpp b/xllm/core/runtime/embed_vlm_worker_impl.cpp index d9f588cd5..b100a5cd8 100644 --- a/xllm/core/runtime/embed_vlm_worker_impl.cpp +++ b/xllm/core/runtime/embed_vlm_worker_impl.cpp @@ -85,7 +85,6 @@ std::optional EmbedVLMWorkerImpl::step( input.sampling_params.is_embeddings) { auto embeddings = model_->pooler(hidden_states, sampling_params.selected_token_idxes); - sample_output.embeddings = embeddings; // split full embeddings and add them to mm_embeddings // so that the user could receive embeddings of images and texts if (::xllm::ModelConfig::get_instance() @@ -99,10 +98,12 @@ std::optional EmbedVLMWorkerImpl::step( sample_output.mm_embeddings.emplace_back(image_embed); token_start_idx += seq_len; } + output.sample_output = sample_output; + } else { + sample_output.embeddings = embeddings; + output.sample_output = sample_output; + output.embedding = embeddings; } - - output.sample_output = sample_output; - output.embedding = embeddings; } ret = device_.synchronize_default_stream(); return output; diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 7907c55c5..ac8f79042 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -349,9 +349,10 @@ inline size_t get_dit_forward_input_size(const DiTForwardInput& input) { return size; } -inline size_t get_dit_forward_output_size(const DiTForwardOutput& output) { - size_t size = type_size; // vector size - for (const auto& tensor : output.tensors) { +inline size_t get_vector_tensor_size( + const std::vector& tensor_vec) { + size_t size = type_size; // vector size + for (const auto& tensor : tensor_vec) { size += get_tensor_size(tensor); } return size; @@ -1024,14 +1025,6 @@ inline void write_dit_forward_input(RawInputSerializeContext& context, write_dit_generation_params(context, input.generation_params); } -inline void write_dit_forward_output(char*& buffer, - const DiTForwardOutput& output) { - write_data(buffer, static_cast(output.tensors.size())); - for (const auto& tensor : output.tensors) { - write_tensor(buffer, tensor); - } -} - inline void safe_advance_buffer(const char*& buffer, size_t offset) { if (buffer != nullptr) { buffer += offset; @@ -1917,16 +1910,6 @@ inline void read_dit_forward_input(ReadContext& context, read_dit_generation_params(context, input.generation_params); } -inline void read_dit_forward_output(const char*& buffer, - DiTForwardOutput& output) { - uint64_t size; - read_data(buffer, size); - output.tensors.resize(size); - for (auto& tensor : output.tensors) { - read_tensor(buffer, tensor); - } -} - inline void initialize_device_buffer_session(ReadContext& context, ForwardInput& forward_input, const torch::Device& device, @@ -2156,7 +2139,9 @@ inline void deserialize_forward_input_payload( input_params.attention.host.block_tables, stream); - read_dit_forward_input(context, input_params.dit_forward_input); + if (FLAGS_backend == "dit") { + read_dit_forward_input(context, input_params.dit_forward_input); + } finalize_device_buffer_session(device_session, stream); forward_input.input_host_buffer_has_layout = true; @@ -2232,8 +2217,12 @@ size_t calculate_raw_forward_output_size(const RawForwardOutput& output) { size += get_vector_size(output.out_tokens); size += get_vector_size(output.out_logprobs); size += type_size; // prepared_layer_id + // mm_embedding_data + size += get_vector_tensor_size(output.mm_embeddings); // dit output data - size += get_dit_forward_output_size(output.dit_forward_output); + if (FLAGS_backend == "dit") { + size += get_vector_tensor_size(output.dit_forward_output.tensors); + } return size; } @@ -2299,8 +2288,11 @@ void deserialize_raw_forward_output(const char* buffer, read_data(buffer, output.prepared_layer_id); read_vector_tensor(buffer, output.mm_embeddings); + // read dit output - read_dit_forward_output(buffer, output.dit_forward_output); + if (FLAGS_backend == "dit") { + read_vector_tensor(buffer, output.dit_forward_output.tensors); + } } void serialize_raw_forward_output(const RawForwardOutput& output, @@ -2316,7 +2308,9 @@ void serialize_raw_forward_output(const RawForwardOutput& output, write_vector_tensor(buffer, output.mm_embeddings); // write dit output - write_dit_forward_output(buffer, output.dit_forward_output); + if (FLAGS_backend == "dit") { + write_vector_tensor(buffer, output.dit_forward_output.tensors); + } } template @@ -2459,8 +2453,9 @@ inline void serialize_forward_input_sections( context, choose_host_or_device_tensor(input_params.attention.host.block_tables, input_params.attention.device.block_tables)); - - write_dit_forward_input(context, input_params.dit_forward_input); + if (FLAGS_backend == "dit") { + write_dit_forward_input(context, input_params.dit_forward_input); + } } inline RawInputLayoutHeader calculate_forward_input_layout( diff --git a/xllm/core/runtime/options.h b/xllm/core/runtime/options.h index 5efab3e51..77bce9261 100644 --- a/xllm/core/runtime/options.h +++ b/xllm/core/runtime/options.h @@ -123,6 +123,10 @@ struct Options { // Default set as 1 PROPERTY(int32_t, cfg_size) = 1; + // vae patch parallelism size + // Default set as 1 + PROPERTY(int32_t, vae_size) = 1; + // enable enable_schedule_overlap to improve runtime execution efficiency. PROPERTY(bool, enable_schedule_overlap) = true; diff --git a/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h b/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h index 2d4ae4d2e..078df4327 100644 --- a/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h +++ b/xllm/models/dit/npu/qwen_image_edit/autoencoder_kl_qwenimage.h @@ -28,11 +28,13 @@ limitations under the License. #include #include +#include "core/common/global_flags.h" #include "core/framework/dit_model_loader.h" #include "core/framework/model/model_input_params.h" #include "core/framework/state_dict/state_dict.h" #include "core/layers/common/add_matmul.h" #include "framework/model_context.h" +#include "framework/parallel_state/parallel_state.h" #include "models/dit/utils/common_util.h" #include "models/model_registry.h" @@ -1451,15 +1453,18 @@ struct DecoderOutput { class AutoencoderKLQwenImageImpl : public torch::nn::Module { public: - AutoencoderKLQwenImageImpl(const ModelContext& context) + AutoencoderKLQwenImageImpl(const ModelContext& context, + const ParallelArgs& parallel_args) : args_(context.get_model_args()), + options_(context.get_tensor_options()), z_dim_(context.get_model_args().z_dim()), temperal_downsample_(context.get_model_args().temperal_downsample()), base_dim_(context.get_model_args().base_dim()), dim_mult_(context.get_model_args().dim_mult()), num_res_blocks_(context.get_model_args().num_res_blocks()), attn_scales_(context.get_model_args().attn_scales()), - dropout_(context.get_model_args().dropout()) { + dropout_(context.get_model_args().dropout()), + parallel_args_(parallel_args) { temperal_upsample_ = std::vector(temperal_downsample_.rbegin(), temperal_downsample_.rend()); @@ -1697,24 +1702,23 @@ class AutoencoderKLQwenImageImpl : public torch::nn::Module { auto weight_a = 1.0 - static_cast(y) / blend_extent; auto weight_b = static_cast(y) / blend_extent; - auto a_slice = a.index( - {torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(-blend_extent + y, -blend_extent + y + 1), - torch::indexing::Slice()}); + auto a_slice = a.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + -blend_extent + y, + torch::indexing::Slice()}); auto b_slice = result_b.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), - torch::indexing::Slice(y, y + 1), + y, torch::indexing::Slice()}); auto blended = a_slice * weight_a + b_slice * weight_b; result_b.index_put_({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), - torch::indexing::Slice(y, y + 1), + y, torch::indexing::Slice()}, blended); } @@ -1732,31 +1736,60 @@ class AutoencoderKLQwenImageImpl : public torch::nn::Module { auto weight_a = 1.0 - static_cast(x) / blend_extent; auto weight_b = static_cast(x) / blend_extent; - auto a_slice = a.index( - {torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(-blend_extent + x, -blend_extent + x + 1)}); + auto a_slice = a.index({torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(), + -blend_extent + x}); auto b_slice = result_b.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), - torch::indexing::Slice(x, x + 1)}); + x}); auto blended = a_slice * weight_a + b_slice * weight_b; result_b.index_put_({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), - torch::indexing::Slice(x, x + 1)}, + x}, blended); } return result_b; } + void pad_to_size(torch::Tensor& input, int64_t target_h, int64_t target_w) { + int64_t input_h = input.size(3); + int64_t input_w = input.size(4); + + int64_t pad_top = (target_h - input_h) / 2; + int64_t pad_bottom = target_h - input_h - pad_top; + int64_t pad_left = (target_w - input_w) / 2; + int64_t pad_right = target_w - input_w - pad_left; + + input = torch::nn::functional::pad( + input, + torch::nn::functional::PadFuncOptions( + {pad_left, pad_right, pad_top, pad_bottom}) + .mode(torch::kConstant) + .value(0)); + } + + void unpad_to_size(torch::Tensor& padded, + int64_t original_h, + int64_t original_w) { + int64_t current_h = padded.size(3); + int64_t current_w = padded.size(4); + + int64_t pad_top = (current_h - original_h) / 2; + int64_t pad_left = (current_w - original_w) / 2; + + padded = padded.slice(3, pad_top, pad_top + original_h) + .slice(4, pad_left, pad_left + original_w); + } + torch::Tensor tiled_encode(const torch::Tensor& x) { auto sizes = x.sizes(); auto b = sizes[0], c = sizes[1], num_frames = sizes[2], height = sizes[3], @@ -1779,44 +1812,173 @@ class AutoencoderKLQwenImageImpl : public torch::nn::Module { std::vector> rows; + // dispatch rows to different ranks + ProcessGroup* vae_group = parallel_args_.dit_vae_group_; + int32_t group_size = vae_group->world_size(); + int32_t local_rank = vae_group->rank(); + bool use_vae_parallel = false; + + // num of patchs on different devices + std::vector num_rank_patchs; + num_rank_patchs.reserve(group_size); + + int32_t local_rows = 0; + int32_t remainder_local_rows = 0; + + int32_t row_start = 0; + int32_t row_end = 0; + int32_t row_size = 0; + if (group_size > 1) { + int32_t num_rows = height / tile_sample_stride_height_; + int32_t remainder_row = (height % tile_sample_stride_height_) > 0 ? 1 : 0; + int32_t global_rows = num_rows + remainder_row; + local_rows = global_rows / group_size; + remainder_local_rows = global_rows % group_size; + row_start = local_rows * local_rank + (local_rank < remainder_local_rows + ? local_rank + : remainder_local_rows); + row_end = row_start + local_rows + + ((local_rank) < remainder_local_rows ? 1 : 0); + use_vae_parallel = local_rows > 1; + } + + // Save tensor and tensor shapes for vae parallel situations. + int32_t row_counter = 0; + std::vector shape_tensors; + std::vector patch_tensors; + for (int64_t i = 0; i < height; i += tile_sample_stride_height_) { - std::vector row; + if ((use_vae_parallel && row_counter >= row_start && + row_counter < row_end) || + !use_vae_parallel) { + std::vector row; + for (int64_t j = 0; j < width; j += tile_sample_stride_width_) { + clear_cache(); + std::vector time_frames; + auto frame_range = 1 + (num_frames - 1) / 4; + + for (int64_t k = 0; k < frame_range; k++) { + enc_conv_idx_->at(0) = 0; + torch::Tensor tile; + + if (k == 0) { + tile = x.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(0, 1), + torch::indexing::Slice(i, i + tile_sample_min_height_), + torch::indexing::Slice(j, j + tile_sample_min_width_)}); + } else { + tile = x.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(1 + 4 * (k - 1), 1 + 4 * k), + torch::indexing::Slice(i, i + tile_sample_min_height_), + torch::indexing::Slice(j, j + tile_sample_min_width_)}); + } + + auto encoded_tile = + encoder_->forward(tile, enc_feat_map_, enc_conv_idx_); + auto quantized_tile = quant_conv_->forward(encoded_tile); + time_frames.push_back(quantized_tile); + } - for (int64_t j = 0; j < width; j += tile_sample_stride_width_) { - clear_cache(); - std::vector time_frames; - auto frame_range = 1 + (num_frames - 1) / 4; - - for (int64_t k = 0; k < frame_range; k++) { - enc_conv_idx_->at(0) = 0; - torch::Tensor tile; - - if (k == 0) { - tile = x.index( - {torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(0, 1), - torch::indexing::Slice(i, i + tile_sample_min_height_), - torch::indexing::Slice(j, j + tile_sample_min_width_)}); + if (use_vae_parallel) { + patch_tensors.emplace_back(torch::cat(time_frames, 2)); + shape_tensors.emplace_back( + torch::tensor(patch_tensors.back().sizes(), options_) + .unsqueeze(0)); } else { - tile = x.index( - {torch::indexing::Slice(), - torch::indexing::Slice(), - torch::indexing::Slice(1 + 4 * (k - 1), 1 + 4 * k), - torch::indexing::Slice(i, i + tile_sample_min_height_), - torch::indexing::Slice(j, j + tile_sample_min_width_)}); + row.push_back(torch::cat(time_frames, 2)); } + } - auto encoded_tile = - encoder_->forward(tile, enc_feat_map_, enc_conv_idx_); - auto quantized_tile = quant_conv_->forward(encoded_tile); - time_frames.push_back(quantized_tile); + if (!use_vae_parallel) { + rows.push_back(row); + } else if (row_size == 0) { + row_size = patch_tensors.size(); } + } + row_counter += 1; + } - row.push_back(torch::cat(time_frames, 2)); + if (use_vae_parallel) { + // calculate num of patchs on different ranks + for (int32_t rank = 0; rank < group_size; rank++) { + int64_t rank_patch = + row_size * (local_rows + ((rank) < remainder_local_rows ? 1 : 0)); + num_rank_patchs.emplace_back(rank_patch); + } + + std::vector gather_shapes; + gather_shapes.reserve(group_size); + for (int32_t rank = 0; rank < group_size; rank++) { + gather_shapes.emplace_back( + torch::empty({num_rank_patchs[rank], 5}, options_)); + gather_shapes.back().print(); + } + auto shape_tensor = torch::cat(shape_tensors, /*dim=*/0); + // gather shapes of patchs on different ranks + vae_group->allgather(shape_tensor, gather_shapes); + + // use the shape of first patch as uniform shape + // we will pad the patchs to the same size to apply all gather + auto uniform_shape = gather_shapes[0][0]; + auto pb = uniform_shape[0].item(), + pc = uniform_shape[1].item(), + pf = uniform_shape[2].item(), + ph = uniform_shape[3].item(), + pw = uniform_shape[4].item(); + + for (auto& patch : patch_tensors) { + pad_to_size(patch, ph, pw); + } + + // cat the patchs alone the last dim + auto merged_local_patch = torch::cat(patch_tensors, /*dim=*/4); + + std::vector gather_patchs; + for (int32_t rank = 0; rank < group_size; rank++) { + auto rank_pw = num_rank_patchs[rank] * pw; + gather_patchs.emplace_back( + torch::empty({pb, pc, pf, ph, rank_pw}, options_)); + } + + // gather patchs on different ranks + vae_group->allgather(merged_local_patch, gather_patchs); + + // chunk patchs to the rows + for (int32_t rank = 0; rank < group_size; rank++) { + auto& rank_patch = gather_patchs[rank]; + auto rank_rows = num_rank_patchs[rank] / row_size; + std::vector patchs = + rank_patch.chunk(num_rank_patchs[rank], -1); + auto& origin_shape = gather_shapes[rank]; + int32_t patch_index = 0; + for (int i = 0; i < rank_rows; i++) { + std::vector row; + for (int j = 0; j < row_size; j++) { + if (patch_index >= patchs.size() || + patch_index >= origin_shape.size(0)) { + LOG(FATAL) << "Patch index " << patch_index + << " is out of bounds for patchs (size " + << patchs.size() << ") or origin_shape (size " + << origin_shape.size(0) << ")."; + } + // unpad patchs to origin size; + unpad_to_size(patchs[patch_index], + origin_shape[patch_index][3].item(), + origin_shape[patch_index][4].item()); + patchs[patch_index].print(); + row.emplace_back(patchs[patch_index]); + patch_index++; + } + // push back to rows; + rows.push_back(row); + } } - rows.push_back(row); } + clear_cache(); std::vector result_rows; @@ -2028,6 +2190,8 @@ class AutoencoderKLQwenImageImpl : public torch::nn::Module { QwenImageDecoder3d decoder_{nullptr}; ModelArgs args_; + const ParallelArgs parallel_args_; + torch::TensorOptions options_; }; TORCH_MODULE(AutoencoderKLQwenImage); diff --git a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h index 399fc4933..6c9b6e3e3 100644 --- a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h +++ b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h @@ -54,7 +54,11 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { prompt_template_encode_start_idx_ = 64; default_sample_size_ = 128; - vae_ = AutoencoderKLQwenImage(context.get_model_context("vae")); + vae_ = AutoencoderKLQwenImage(context.get_model_context("vae"), + parallel_args_); + if (::xllm::DiTConfig::get_instance().enable_dit_vae_tiling()) { + vae_->enable_tiling(); + } transformer_ = QwenImageTransformer2DModel( context.get_model_context("transformer"), parallel_args_); scheduler_ = diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp index da493adac..bd85e7e0a 100644 --- a/xllm/xllm.cpp +++ b/xllm/xllm.cpp @@ -208,6 +208,7 @@ Options create_options(const std::string& instance_name, bool is_local) { .tp_size(static_cast(parallel_config.tp_size())) .sp_size(static_cast(parallel_config.sp_size())) .cfg_size(static_cast(parallel_config.cfg_size())) + .vae_size(static_cast(parallel_config.vae_size())) .instance_name(instance_name) .enable_disagg_pd(disagg_pd_config.enable_disagg_pd()) .enable_pd_ooc(disagg_pd_config.enable_pd_ooc())