Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ DECLARE_int64(sp_size);

DECLARE_int64(cfg_size);

DECLARE_int64(vae_size);

DECLARE_bool(enable_prefill_sp);

DECLARE_bool(enable_multi_stream_parallel);
Expand Down Expand Up @@ -333,6 +335,8 @@ DECLARE_int64(dit_sp_communication_overlap);

DECLARE_bool(dit_debug_print);

DECLARE_bool(enable_dit_vae_tiling);

DECLARE_bool(use_audio_in_video);

// --- kernel config ---
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class Options {

PROPERTY(int32_t, cfg_size) = 1;

PROPERTY(int32_t, vae_size) = 1;

PROPERTY(std::optional<std::string>, instance_name);

PROPERTY(bool, enable_disagg_pd) = false;
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/distributed_runtime/dist_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,13 @@ void DistManager::setup_multi_node_workers(
const int32_t tp_size = options.tp_size();
const int32_t sp_size = options.sp_size();
const int32_t cfg_size = options.cfg_size();
const int32_t vae_size = options.vae_size();
LOG(INFO) << "Multi-node serving world_size = " << world_size
<< ", each_node_ranks = " << each_node_ranks
<< ", current node rank = " << options.node_rank()
<< ", nnodes = " << options.nnodes() << ", dp_size = " << dp_size
<< ", tp_size = " << tp_size << ", sp_size = " << sp_size
<< ", cfg_size = " << cfg_size;
<< ", cfg_size = " << cfg_size << ", vae_size = " << vae_size;
} else {
LOG(INFO) << "Multi-node serving world_size = " << world_size
<< ", each_node_ranks = " << each_node_ranks
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/distributed_runtime/master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,8 @@ Master::Master(const Options& options, EngineType type)
.ep_size(options_.ep_size())
.tp_size(options_.tp_size())
.sp_size(options_.sp_size())
.cfg_size(options_.cfg_size());
.cfg_size(options_.cfg_size())
.vae_size(options_.vae_size());

auto dit_engine = std::make_unique<DiTEngine>(eng_options);
engine_ = std::move(dit_engine);
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/distributed_runtime/worker_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ void WorkerServer::create_server(
options.dp_size(),
options.tp_size(),
options.sp_size(),
options.cfg_size());
options.cfg_size(),
options.vae_size());
comm = std::move(dit_comm);
} else {
auto common_comm = std::make_unique<CollectiveCommunicator>(
Expand Down
14 changes: 12 additions & 2 deletions xllm/core/framework/config/dit_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ DEFINE_bool(dit_debug_print,
false,
"whether print the debug info for dit models");

// --- dit vae tiling ---

// Boolean flag, false by default. DiTConfig::from_flags() copies
// FLAGS_enable_dit_vae_tiling into the DiTConfig option of the same name, and
// DiTConfig::from_json() may override it via the "enable_dit_vae_tiling" key.
DEFINE_bool(
enable_dit_vae_tiling,
false,
"whether enable vae tiling, currently only support qwen-image-edit-plus");

namespace xllm {

void DiTConfig::from_flags() {
Expand All @@ -76,7 +83,8 @@ void DiTConfig::from_flags() {
.dit_cache_start_blocks(FLAGS_dit_cache_start_blocks)
.dit_cache_end_blocks(FLAGS_dit_cache_end_blocks)
.dit_sp_communication_overlap(FLAGS_dit_sp_communication_overlap)
.dit_debug_print(FLAGS_dit_debug_print);
.dit_debug_print(FLAGS_dit_debug_print)
.enable_dit_vae_tiling(FLAGS_enable_dit_vae_tiling);
}

void DiTConfig::from_json(const JsonReader& json) {
Expand Down Expand Up @@ -104,7 +112,9 @@ void DiTConfig::from_json(const JsonReader& json) {
.dit_sp_communication_overlap(json.value_or<int64_t>(
"dit_sp_communication_overlap", dit_sp_communication_overlap()))
.dit_debug_print(
json.value_or<bool>("dit_debug_print", dit_debug_print()));
json.value_or<bool>("dit_debug_print", dit_debug_print()))
.enable_dit_vae_tiling(json.value_or<bool>("enable_dit_vae_tiling",
enable_dit_vae_tiling()));
}

DiTConfig& DiTConfig::get_instance() {
Expand Down
5 changes: 4 additions & 1 deletion xllm/core/framework/config/dit_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class DiTConfig final {
"dit_cache_start_blocks",
"dit_cache_end_blocks",
"dit_sp_communication_overlap",
"dit_debug_print"}};
"dit_debug_print",
"enable_dit_vae_tiling"}};
return kOptionCategory;
}

Expand All @@ -77,6 +78,8 @@ class DiTConfig final {
PROPERTY(int64_t, dit_sp_communication_overlap) = 1;

PROPERTY(bool, dit_debug_print) = false;

PROPERTY(bool, enable_dit_vae_tiling) = false;
};

} // namespace xllm
4 changes: 4 additions & 0 deletions xllm/core/framework/config/parallel_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ DEFINE_int64(cfg_size,
"Classifier-free guidiance parallelism size, only used for DiT "
"model.");

// 64-bit integer flag, 1 by default. ParallelConfig::from_flags() copies
// FLAGS_vae_size into the ParallelConfig `vae_size` option, and
// ParallelConfig::from_json() may override it via the "vae_size" key.
DEFINE_int64(vae_size, 1, "Vae patch parallelism size");

DEFINE_string(
communication_backend,
"hccl",
Expand Down Expand Up @@ -67,6 +69,7 @@ void ParallelConfig::from_flags() {
.tp_size(FLAGS_tp_size)
.sp_size(FLAGS_sp_size)
.cfg_size(FLAGS_cfg_size)
.vae_size(FLAGS_vae_size)
.communication_backend(FLAGS_communication_backend)
.enable_prefill_sp(FLAGS_enable_prefill_sp)
.enable_multi_stream_parallel(FLAGS_enable_multi_stream_parallel)
Expand All @@ -81,6 +84,7 @@ void ParallelConfig::from_json(const JsonReader& json) {
.tp_size(json.value_or<int64_t>("tp_size", tp_size()))
.sp_size(json.value_or<int64_t>("sp_size", sp_size()))
.cfg_size(json.value_or<int64_t>("cfg_size", cfg_size()))
.vae_size(json.value_or<int64_t>("vae_size", vae_size()))
.communication_backend(json.value_or<std::string>(
"communication_backend", communication_backend()))
.enable_prefill_sp(
Expand Down
3 changes: 3 additions & 0 deletions xllm/core/framework/config/parallel_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ParallelConfig final {
"tp_size",
"sp_size",
"cfg_size",
"vae_size",
"communication_backend",
"enable_prefill_sp",
"enable_multi_stream_parallel",
Expand All @@ -65,6 +66,8 @@ class ParallelConfig final {

PROPERTY(int64_t, cfg_size) = 1;

PROPERTY(int64_t, vae_size) = 1;

PROPERTY(std::string, communication_backend) = "hccl";

PROPERTY(bool, enable_prefill_sp) = false;
Expand Down
162 changes: 63 additions & 99 deletions xllm/core/framework/parallel_state/dit_collective_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,23 @@ DiTCollectiveCommunicator::DiTCollectiveCommunicator(int32_t global_rank,
int32_t dit_dp_size,
int32_t dit_tp_size,
int32_t dit_sp_size,
int32_t dit_cfg_size)
int32_t dit_cfg_size,
int32_t dit_vae_size)
: CollectiveCommunicatorBase(global_rank, world_size) {
parallel_args_ = std::make_unique<ParallelArgs>(global_rank,
world_size,
dit_dp_size,
dit_tp_size,
dit_sp_size,
dit_cfg_size,
dit_vae_size,
/*process_group=*/nullptr);
DiTMapping::Options dit_mapping_options;
dit_mapping_options.dit_tp_size(dit_tp_size)
.dit_sp_size(dit_sp_size)
.dit_cfg_size(dit_cfg_size)
.dit_dp_size(dit_dp_size);
.dit_dp_size(dit_dp_size)
.dit_vae_size(dit_vae_size);
dit_mapping_ = std::make_unique<DiTMapping>(
world_size, global_rank, dit_mapping_options);
}
Expand All @@ -63,114 +66,75 @@ void DiTCollectiveCommunicator::create_process_groups(
const torch::Device& device) {
Device device_(device);
device_.set_device();
std::string host;
int32_t port;
net::parse_host_port_from_addr(master_addr, host, port);

int32_t global_rank = parallel_args_->rank();
int32_t world_size = parallel_args_->world_size();
int32_t dp_size = parallel_args_->dp_size();
int32_t tp_size = parallel_args_->tp_size();
int32_t sp_size = parallel_args_->sp_size();
int32_t cfg_size = parallel_args_->cfg_size();

process_group_ = create_process_group(global_rank,
world_size,
world_size,
++port,
net::parse_host_port_from_addr(master_addr, host_, port_);

process_group_ = create_process_group(global_rank_,
world_size_,
world_size_,
++port_,
false,
host,
host_,
"world_group",
device);

parallel_args_->process_group_ = process_group_.get();

if (tp_size > 1 && dit_mapping_) {
auto tp_parallel_info = dit_mapping_->get_parallel_info("tp");
auto group_id = tp_parallel_info.current_group_id();
auto num_group = tp_parallel_info.num_group();
auto local_rank = tp_parallel_info.rank();
auto& rank_per_group = tp_parallel_info.rank_per_group()[group_id];
int port_offset = group_id + 1;
#if defined(USE_NPU) || defined(USE_MLU)
dit_tp_group_ = create_process_group(global_rank,
local_rank,
rank_per_group,
world_size,
tp_size,
port + port_offset,
host,
"tp_group",
device);
#endif
parallel_args_->dit_tp_group_ = dit_tp_group_.get();
port += num_group;
}

if (sp_size > 1 && dit_mapping_) {
auto sp_parallel_info = dit_mapping_->get_parallel_info("sp");
auto group_id = sp_parallel_info.current_group_id();
auto num_group = sp_parallel_info.num_group();
auto local_rank = sp_parallel_info.rank();
auto& rank_per_group = sp_parallel_info.rank_per_group()[group_id];
int port_offset = group_id + 1;
#if defined(USE_NPU) || defined(USE_MLU)
dit_sp_group_ = create_process_group(global_rank,
local_rank,
rank_per_group,
world_size,
sp_size,
port + port_offset,
host,
"sp_group",
device);
#endif
parallel_args_->dit_sp_group_ = dit_sp_group_.get();
port += num_group;
}

if (cfg_size > 1 && dit_mapping_) {
auto cfg_parallel_info = dit_mapping_->get_parallel_info("cfg");
auto group_id = cfg_parallel_info.current_group_id();
auto num_group = cfg_parallel_info.num_group();
auto local_rank = cfg_parallel_info.rank();
auto& rank_per_group = cfg_parallel_info.rank_per_group()[group_id];
int port_offset = group_id + 1;
#if defined(USE_NPU) || defined(USE_MLU)
dit_cfg_group_ = create_process_group(global_rank,
local_rank,
rank_per_group,
world_size,
cfg_size,
port + port_offset,
host,
"cfg_group",
device);
#endif
parallel_args_->dit_cfg_group_ = dit_cfg_group_.get();
port += num_group;
std::unordered_map<std::string, ProcessGroup**> groups = {
{"tp", &parallel_args_->dit_tp_group_},
{"sp", &parallel_args_->dit_sp_group_},
{"cfg", &parallel_args_->dit_cfg_group_},
{"dp", &parallel_args_->dit_dp_group_},
{"vae", &parallel_args_->dit_vae_group_}};

// We use class members to extend the lifetime of these ProcessGroups, so that
// they won't be destructed.
group_map_ = {{"tp", &dit_tp_group_},
{"sp", &dit_sp_group_},
{"dp", &dit_dp_group_},
{"cfg", &dit_cfg_group_},
{"vae", &dit_vae_group_}};

for (auto& [group_type, process_group] : groups) {
create_process_group_by_type(group_type, *process_group, device);
}
}

if (dp_size > 1 && dit_mapping_) {
auto dp_parallel_info = dit_mapping_->get_parallel_info("dp");
auto group_id = dp_parallel_info.current_group_id();
auto num_group = dp_parallel_info.num_group();
auto local_rank = dp_parallel_info.rank();
auto& rank_per_group = dp_parallel_info.rank_per_group()[group_id];
// Creates the process group for a single DiT parallelism dimension and
// publishes it through `process_group`.
//
// @param group_type     Dimension key (create_process_groups() passes "tp",
//                       "sp", "cfg", "dp" and "vae"). It is used to query the
//                       group size from parallel_args_, to fetch the rank
//                       layout from dit_mapping_, to pick the owning member
//                       through group_map_, and to build the group name
//                       "<group_type>_group".
// @param process_group  Out-parameter that receives a non-owning pointer to
//                       the created group. It is left untouched when no group
//                       is created (group size <= 1, dit_mapping_ unset, or a
//                       build without USE_NPU / USE_MLU).
// @param device         Forwarded unchanged to create_process_group().
//
// Side effect: whenever the dimension is active, port_ is advanced by the
// number of groups in that dimension.
// NOTE(review): because port_ is advanced per call, the port each dimension
// ends up with depends on call order; presumably every rank must invoke this
// for the dimensions in the same order -- confirm against the caller.
void DiTCollectiveCommunicator::create_process_group_by_type(
    const std::string& group_type,
    ProcessGroup*& process_group,
    const torch::Device& device) {
  const int32_t group_size =
      parallel_args_->get_group_size_by_type(group_type);
  if (group_size <= 1 || !dit_mapping_) {
    // Dimension is disabled or there is no rank mapping: nothing to create.
    return;
  }

  auto parallel_info = dit_mapping_->get_parallel_info(group_type);
  auto group_id = parallel_info.current_group_id();
  auto num_group = parallel_info.num_group();
  auto local_rank = parallel_info.rank();
  auto& rank_per_group = parallel_info.rank_per_group()[group_id];
  // Each group of this dimension gets its own port: port_ + 1 + group_id.
  int port_offset = group_id + 1;
#if defined(USE_NPU) || defined(USE_MLU)
  // Use find() rather than operator[]: operator[] would default-insert a null
  // entry for an unregistered group type and the dereference below would then
  // be on a null pointer. This method is public, so it can be reached before
  // create_process_groups() has filled group_map_.
  const auto owner_it = group_map_.find(group_type);
  if (owner_it == group_map_.end() || owner_it->second == nullptr) {
    LOG(FATAL) << "No owning process group registered for DiT group type '"
               << group_type << "'";
  }
  // The class member owns the group; callers only see the raw pointer.
  std::unique_ptr<ProcessGroup>& owner = *owner_it->second;
  owner = create_process_group(global_rank_,
                               local_rank,
                               rank_per_group,
                               world_size_,
                               group_size,
                               port_ + port_offset,
                               host_,
                               group_type + "_group",
                               device);
  process_group = owner.get();
#else
  LOG(INFO) << "create_process_group function is used by DiT models, since "
               "the DiT communication group "
            << "info have already been calculated by rank_generator, we only "
               "need to pass the "
            << "info to create the process groups. For any device that want "
               "to reuse the "
            << "function and dit process groups, please implement the "
               "corresponding "
            << "ProcessGroupImpl construct function. ";
#endif
  // Step past the ports consumed by this dimension's groups.
  port_ += num_group;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class DiTCollectiveCommunicator : public CollectiveCommunicatorBase {
int32_t dit_dp_size,
int32_t dit_tp_size,
int32_t dit_sp_size,
int32_t dit_cfg_size);
int32_t dit_cfg_size,
int32_t dit_vae_size);

~DiTCollectiveCommunicator() = default;

Expand All @@ -37,14 +38,22 @@ class DiTCollectiveCommunicator : public CollectiveCommunicatorBase {
// init communicator and return parallel args.
const ParallelArgs* parallel_args() override;

void create_process_group_by_type(const std::string& group_type,
ProcessGroup*& process_group,
const torch::Device& device);

private:
int32_t port_ = 0;
std::string host_ = "";
std::unique_ptr<DiTMapping> dit_mapping_{nullptr};
std::unique_ptr<ParallelArgs> parallel_args_;
std::unique_ptr<ProcessGroup> process_group_;
std::unique_ptr<ProcessGroup> dit_tp_group_;
std::unique_ptr<ProcessGroup> dit_sp_group_;
std::unique_ptr<ProcessGroup> dit_dp_group_;
std::unique_ptr<ProcessGroup> dit_cfg_group_;
std::unique_ptr<ProcessGroup> dit_vae_group_;
std::unordered_map<std::string, std::unique_ptr<ProcessGroup>*> group_map_;
};

} // namespace xllm
Loading
Loading