Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions xllm/models/dit/npu/wan2_2/pipeline_wan2_2_i2v.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#include <torch/torch.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>

#include "autoencoder_kl_wan.h"
#include "core/framework/dit_model_loader.h"
Expand Down Expand Up @@ -409,38 +409,9 @@ class Wan2_2I2VPipelineImpl : public torch::nn::Module {

int64_t dw = vae_scale_factor_spatial_ * patch_size_w;
int64_t dh = vae_scale_factor_spatial_ * patch_size_h;
int64_t max_area = height * width; // size param treated as area constraint

// Get actual image dimensions to derive aspect ratio
int64_t ih = images.has_value() ? images.value().size(-2) : height;
int64_t iw = images.has_value() ? images.value().size(-1) : width;
double ratio = static_cast<double>(iw) / ih;

// Candidate 1: floor width first
int64_t ow1 = static_cast<int64_t>(std::sqrt(max_area * ratio)) / dw * dw;
int64_t oh1 = (max_area / ow1) / dh * dh;
double ratio1 = static_cast<double>(ow1) / oh1;

// Candidate 2: floor height first
int64_t oh2 = static_cast<int64_t>(std::sqrt(max_area / ratio)) / dh * dh;
int64_t ow2 = (max_area / oh2) / dw * dw;
double ratio2 = static_cast<double>(ow2) / oh2;

// Pick the one that preserves aspect ratio better
int64_t calc_height, calc_width;
if (std::max(ratio / ratio1, ratio1 / ratio) <
std::max(ratio / ratio2, ratio2 / ratio)) {
calc_width = ow1;
calc_height = oh1;
} else {
calc_width = ow2;
calc_height = oh2;
}

if (height != calc_height || width != calc_width) {
height = calc_height;
width = calc_width;
}
// Call unified function for dimension adjustment
AdjustVideoSize(images, height, width, dw, dh, false);

if (boundary_ratio_ > 0.0f && guidance_scale_2 < 0.0f) {
guidance_scale_2 = guidance_scale;
Expand Down Expand Up @@ -628,6 +599,62 @@ class Wan2_2I2VPipelineImpl : public torch::nn::Module {
return video;
}

void AdjustVideoSize(const std::optional<torch::Tensor>& images,
int64_t& height,
int64_t& width,
int64_t dw,
int64_t dh,
bool use_user_priority) {
if (use_user_priority) {
// User priority mode: directly use user-specified dimensions
// Only enforce alignment to 16x multiple (dw, dh)
if (height % dh != 0) {
height = (height / dh) * dh;
LOG(WARNING) << "Height adjusted to " << height << " (multiple of "
<< dh << ")";
}
if (width % dw != 0) {
width = (width / dw) * dw;
LOG(WARNING) << "Width adjusted to " << width << " (multiple of " << dw
<< ")";
}
} else {
// Original logic: choose best candidate based on aspect ratio and area
int64_t max_area = height * width;
int64_t ih = images.has_value() ? images.value().size(-2) : height;
int64_t iw = images.has_value() ? images.value().size(-1) : width;
double ratio = static_cast<double>(iw) / ih;

// Candidate 1: floor width first
int64_t ow1 = static_cast<int64_t>(std::sqrt(max_area * ratio)) / dw * dw;
int64_t oh1 = (max_area / ow1) / dh * dh;
double ratio1 = static_cast<double>(ow1) / oh1;

// Candidate 2: floor height first
int64_t oh2 = static_cast<int64_t>(std::sqrt(max_area / ratio)) / dh * dh;
int64_t ow2 = (max_area / oh2) / dw * dw;
double ratio2 = static_cast<double>(ow2) / oh2;

// Pick the candidate that preserves aspect ratio better
int64_t calc_height, calc_width;
if (std::max(ratio / ratio1, ratio1 / ratio) <
std::max(ratio / ratio2, ratio2 / ratio)) {
calc_width = ow1;
calc_height = oh1;
} else {
calc_width = ow2;
calc_height = oh2;
}

if (height != calc_height || width != calc_width) {
height = calc_height;
width = calc_width;
LOG(INFO) << "Size adjusted by aspect ratio: height=" << height
<< ", width=" << width;
}
}
}

private:
UniPCMultistepScheduler scheduler_{nullptr};
AutoencoderKLWan vae_{nullptr};
Expand Down