Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 54 additions & 29 deletions xllm/models/dit/npu/wan2_2/pipeline_wan2_2_i2v.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include <torch/torch.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>

#include "autoencoder_kl_wan.h"
#include "core/framework/dit_model_loader.h"
Expand Down Expand Up @@ -423,37 +423,25 @@ class Wan2_2I2VPipelineImpl : public torch::nn::Module {

int64_t dw = vae_scale_factor_spatial_ * patch_size_w;
int64_t dh = vae_scale_factor_spatial_ * patch_size_h;
int64_t max_area = height * width; // size param treated as area constraint

// Get actual image dimensions to derive aspect ratio
int64_t ih = images.has_value() ? images.value().size(-2) : height;
int64_t iw = images.has_value() ? images.value().size(-1) : width;
double ratio = static_cast<double>(iw) / ih;

// Candidate 1: floor width first
int64_t ow1 = static_cast<int64_t>(std::sqrt(max_area * ratio)) / dw * dw;
int64_t oh1 = (max_area / ow1) / dh * dh;
double ratio1 = static_cast<double>(ow1) / oh1;
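// Hard-coded switch (edit as needed): true = the user-requested size takes priority, false = keep the original image-aspect-ratio behavior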
static const bool kUseUserSizePriority = false; // 按需修改此处

// Candidate 2: floor height first
int64_t oh2 = static_cast<int64_t>(std::sqrt(max_area / ratio)) / dh * dh;
int64_t ow2 = (max_area / oh2) / dw * dw;
double ratio2 = static_cast<double>(ow2) / oh2;

// Pick the one that preserves aspect ratio better
int64_t calc_height, calc_width;
if (std::max(ratio / ratio1, ratio1 / ratio) <
std::max(ratio / ratio2, ratio2 / ratio)) {
calc_width = ow1;
calc_height = oh1;
if (kUseUserSizePriority) {
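// User-requested size has the highest priority: only force alignment down to the required multiple (dh / dw; 16 per the original note)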
if (height % dh != 0) {
height = (height / dh) * dh;
LOG(WARNING) << "Height adjusted to " << height << " (multiple of "
<< dh << ")";
}
if (width % dw != 0) {
width = (width / dw) * dw;
LOG(WARNING) << "Width adjusted to " << width << " (multiple of " << dw
<< ")";
}
} else {
calc_width = ow2;
calc_height = oh2;
}

if (height != calc_height || width != calc_width) {
height = calc_height;
width = calc_width;
// Call the helper that encapsulates the original aspect-ratio sizing logic
AdjustSizeByAspectRatio(images, height, width, dw, dh);
}

if (boundary_ratio_ > 0.0f && guidance_scale_2 < 0.0f) {
Expand Down Expand Up @@ -645,6 +633,43 @@ class Wan2_2I2VPipelineImpl : public torch::nn::Module {
return video;
}

// Encapsulates the original sizing logic: re-shapes the requested output size
// so that it follows the aspect ratio of the input image.
//
// The requested `height * width` is treated as an area budget. Two candidate
// sizes are built, both rounded down to the alignment units (`dw` for width,
// `dh` for height):
//   1. width derived from sqrt(area * ratio) first, height from what is left;
//   2. height derived from sqrt(area / ratio) first, width from what is left.
// The candidate whose width/height ratio is closest to the image ratio wins
// (ties go to candidate 2) and is written back to `height` / `width`.
//
// @param images  Optional reference image; its last two dimensions are read as
//                (height, width). When absent, the requested size itself
//                defines the aspect ratio.
// @param height  In/out: requested height, replaced by the adjusted height.
// @param width   In/out: requested width, replaced by the adjusted width.
// @param dw      Width alignment unit; must be positive.
// @param dh      Height alignment unit; must be positive.
//
// If the inputs are degenerate (non-positive sizes or alignment units, or an
// area budget too small to hold a single aligned candidate), `height` and
// `width` are left unchanged instead of dividing by zero or producing a
// zero-sized output.
void AdjustSizeByAspectRatio(const std::optional<torch::Tensor>& images,
                             int64_t& height,
                             int64_t& width,
                             int64_t dw,
                             int64_t dh) {
  // Every division below relies on these being strictly positive.
  if (dw <= 0 || dh <= 0 || height <= 0 || width <= 0) {
    LOG(INFO) << "Skip aspect-ratio size adjustment, invalid input: height="
              << height << ", width=" << width << ", dw=" << dw
              << ", dh=" << dh;
    return;
  }

  const int64_t max_area = height * width;
  const int64_t ih = images.has_value() ? images.value().size(-2) : height;
  const int64_t iw = images.has_value() ? images.value().size(-1) : width;
  if (ih <= 0 || iw <= 0) {
    LOG(INFO) << "Skip aspect-ratio size adjustment, empty image: ih=" << ih
              << ", iw=" << iw;
    return;
  }
  const double ratio = static_cast<double>(iw) / ih;

  // Candidate 1: floor the width first. A width that floors to 0 makes the
  // candidate unusable, so the dependent height is not computed from it.
  const int64_t ow1 =
      static_cast<int64_t>(std::sqrt(max_area * ratio)) / dw * dw;
  const int64_t oh1 = ow1 > 0 ? (max_area / ow1) / dh * dh : 0;

  // Candidate 2: floor the height first, with the same zero protection.
  const int64_t oh2 =
      static_cast<int64_t>(std::sqrt(max_area / ratio)) / dh * dh;
  const int64_t ow2 = oh2 > 0 ? (max_area / oh2) / dw * dw : 0;

  const bool valid1 = ow1 > 0 && oh1 > 0;
  const bool valid2 = ow2 > 0 && oh2 > 0;
  if (!valid1 && !valid2) {
    LOG(INFO) << "Skip aspect-ratio size adjustment, area " << max_area
              << " is too small for alignment dw=" << dw << ", dh=" << dh;
    return;
  }

  bool use_first = valid1;
  if (valid1 && valid2) {
    // Both usable: keep the one whose aspect ratio deviates least from the
    // image ratio. The deviation is symmetric (max of r/r' and r'/r).
    const double ratio1 = static_cast<double>(ow1) / oh1;
    const double ratio2 = static_cast<double>(ow2) / oh2;
    use_first = std::max(ratio / ratio1, ratio1 / ratio) <
                std::max(ratio / ratio2, ratio2 / ratio);
  }

  const int64_t calc_width = use_first ? ow1 : ow2;
  const int64_t calc_height = use_first ? oh1 : oh2;

  if (height != calc_height || width != calc_width) {
    height = calc_height;
    width = calc_width;
    LOG(INFO) << "Size adjusted by aspect ratio: height=" << height
              << ", width=" << width;
  }
}

private:
UniPCMultistepScheduler scheduler_{nullptr};
AutoencoderKLWan vae_{nullptr};
Expand Down