diff --git a/diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_binding.cpp b/diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_binding.cpp new file mode 100644 index 0000000..c08f52e --- /dev/null +++ b/diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_binding.cpp @@ -0,0 +1,73 @@ +#include + +torch::Tensor rotary_emb_forward_cuda(torch::Tensor x, torch::Tensor freqs_cis); +std::vector modulate_forward_cuda(torch::Tensor x, torch::Tensor mod_params); +std::vector modulate_indexed_forward_cuda( + torch::Tensor x, + torch::Tensor mod_params, + torch::Tensor index +); + +torch::Tensor rotary_emb_forward(torch::Tensor x, torch::Tensor freqs_cis) { + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(freqs_cis.is_cuda(), "freqs_cis must be a CUDA tensor"); + TORCH_CHECK(x.is_contiguous(), "x must be contiguous"); + TORCH_CHECK(freqs_cis.is_contiguous(), "freqs_cis must be contiguous"); + TORCH_CHECK(x.dim() == 4, "x must have shape [B, S, H, D]"); + TORCH_CHECK(freqs_cis.dim() == 2, "freqs_cis must have shape [S, D/2]"); + TORCH_CHECK(x.size(1) == freqs_cis.size(0), "x sequence dimension must match freqs_cis"); + TORCH_CHECK((x.size(3) % 2) == 0, "last dim of x must be even"); + TORCH_CHECK(freqs_cis.scalar_type() == at::kComplexFloat, "freqs_cis must be complex64"); + TORCH_CHECK( + (x.scalar_type() == at::kFloat) || (x.scalar_type() == at::kHalf) || (x.scalar_type() == at::kBFloat16), + "x dtype must be float32, float16, or bfloat16" + ); + TORCH_CHECK(x.size(3) / 2 == freqs_cis.size(1), "freqs_cis second dim must be D/2"); + return rotary_emb_forward_cuda(x, freqs_cis); +} + +std::vector modulate_forward(torch::Tensor x, torch::Tensor mod_params) { + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(mod_params.is_cuda(), "mod_params must be a CUDA tensor"); + TORCH_CHECK(x.is_contiguous(), "x must be contiguous"); + TORCH_CHECK(mod_params.is_contiguous(), "mod_params must be contiguous"); + TORCH_CHECK(x.dim() == 3, "x must have shape [B, S, D]"); + TORCH_CHECK(mod_params.dim() == 2, "mod_params must have shape [B, 3D]"); + TORCH_CHECK(x.size(0) == mod_params.size(0), "mod_params batch must equal x batch"); + TORCH_CHECK(mod_params.size(1) == x.size(2) * 3, "mod_params last dim must be 3 * D"); + TORCH_CHECK(mod_params.scalar_type() == x.scalar_type(), "mod_params dtype must match x dtype"); + TORCH_CHECK( + (x.scalar_type() == at::kFloat) || (x.scalar_type() == at::kHalf) || (x.scalar_type() == at::kBFloat16), + "x dtype must be float32, float16, or bfloat16" + ); + return modulate_forward_cuda(x, mod_params); +} + +std::vector modulate_indexed_forward(torch::Tensor x, torch::Tensor mod_params, torch::Tensor index) { + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(mod_params.is_cuda(), "mod_params must be a CUDA tensor"); + TORCH_CHECK(index.is_cuda(), "index must be a CUDA tensor"); + TORCH_CHECK(x.is_contiguous(), "x must be contiguous"); + TORCH_CHECK(mod_params.is_contiguous(), "mod_params must be contiguous"); + TORCH_CHECK(index.is_contiguous(), "index must be contiguous"); + TORCH_CHECK(x.dim() == 3, "x must have shape [B, S, D]"); + TORCH_CHECK(mod_params.dim() == 2, "mod_params must have shape [2B, 3D]"); + TORCH_CHECK(index.dim() == 3, "index must have shape [B or 1, S, 1]"); + TORCH_CHECK(index.size(1) == x.size(1), "index sequence dim must match x sequence dim"); + TORCH_CHECK(index.size(2) == 1, "index last dim must be 1"); + TORCH_CHECK(mod_params.size(0) == x.size(0) * 2, "mod_params batch must be 2 * x batch"); + TORCH_CHECK(mod_params.size(1) == x.size(2) * 3, "mod_params last dim must be 3 * D"); + TORCH_CHECK(mod_params.scalar_type() == x.scalar_type(), "mod_params dtype must match x dtype"); + TORCH_CHECK(index.scalar_type() == at::kLong || index.scalar_type() == at::kInt, "index must be int32 or int64"); + TORCH_CHECK( + (x.scalar_type() == at::kFloat) || (x.scalar_type() == at::kHalf) || (x.scalar_type() == at::kBFloat16), + "x dtype must be float32, float16, or bfloat16" + ); + return modulate_indexed_forward_cuda(x, mod_params, index); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rotary_emb_forward", &rotary_emb_forward, "Qwen rotary embedding forward (CUDA)"); + m.def("modulate_forward", &modulate_forward, "Qwen modulation forward (CUDA)"); + m.def("modulate_indexed_forward", &modulate_indexed_forward, "Qwen indexed modulation forward (CUDA)"); +} diff --git a/diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_kernel.cu b/diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_kernel.cu new file mode 100644 index 0000000..1ca7e12 --- /dev/null +++ b/diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_kernel.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include + +#include +#include + +template +__global__ void rotary_emb_kernel( + const scalar_t* __restrict__ x, + const c10::complex* __restrict__ freqs, + scalar_t* __restrict__ out, + int64_t bsz, + int64_t seq, + int64_t heads, + int64_t dim +) { + const int64_t half_dim = dim >> 1; + const int64_t total_pairs = bsz * seq * heads * half_dim; + const int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= total_pairs) return; + + int64_t t = idx; + const int64_t pair_idx = t % half_dim; + t /= half_dim; + const int64_t head_idx = t % heads; + t /= heads; + const int64_t seq_idx = t % seq; + const int64_t batch_idx = t / seq; + + const int64_t base = ((batch_idx * seq + seq_idx) * heads + head_idx) * dim + (pair_idx << 1); + const float x0 = static_cast(x[base]); + const float x1 = static_cast(x[base + 1]); + const c10::complex f = freqs[seq_idx * half_dim + pair_idx]; + const float fr = f.real(); + const float fi = f.imag(); + + const float y0 = x0 * fr - x1 * fi; + const float y1 = x0 * fi + x1 * fr; + out[base] = static_cast(y0); + out[base + 1] = static_cast(y1); +} + +torch::Tensor rotary_emb_forward_cuda(torch::Tensor x, torch::Tensor freqs_cis) { + auto out = torch::empty_like(x); + const int64_t bsz = x.size(0); + const int64_t seq = x.size(1); + const int64_t heads = x.size(2); + const int64_t dim = x.size(3); + const int64_t total_pairs = bsz * seq * heads * (dim >> 1); + + if (total_pairs == 0) { + return out; + } + + constexpr int threads = 256; + const int blocks = static_cast((total_pairs + threads - 1) / threads); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + x.scalar_type(), + "qwen_image_rotary_emb_cuda", + [&] { + rotary_emb_kernel<<>>( + x.data_ptr(), + freqs_cis.data_ptr>(), + out.data_ptr(), + bsz, + seq, + heads, + dim + ); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return out; +} + + +template +__global__ void modulate_kernel( + const scalar_t* __restrict__ x, + const scalar_t* __restrict__ shift, + const scalar_t* __restrict__ scale, + scalar_t* __restrict__ modulated, + int64_t bsz, + int64_t seq, + int64_t dim +) { + const int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int64_t total = bsz * seq * dim; + if (idx >= total) return; + + const int64_t d = idx % dim; + const int64_t t = idx / dim; + const int64_t b = t / seq; + const int64_t base_vec = b * dim + d; + + const scalar_t xv = x[idx]; + const scalar_t shiftv = shift[base_vec]; + const scalar_t scalev = scale[base_vec]; + modulated[idx] = xv * (static_cast(1.0f) + scalev) + shiftv; +} + + +template +__global__ void modulate_indexed_kernel( + const scalar_t* __restrict__ x, + const scalar_t* __restrict__ shift, + const scalar_t* __restrict__ scale, + const scalar_t* __restrict__ gate, + const index_t* __restrict__ index, + scalar_t* __restrict__ modulated, + scalar_t* __restrict__ gate_out, + int64_t bsz, + int64_t seq, + int64_t dim, + int64_t index_bsz +) { + const int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const int64_t total = bsz * seq * dim; + if (idx >= total) return; + + const int64_t d = idx % dim; + const int64_t t = idx / dim; + const int64_t s = t % seq; + const int64_t b = t / seq; + + const int64_t ib = (index_bsz == 1) ? 0 : b; + const int64_t index_offset = (ib * seq + s); + const int64_t select = static_cast(index[index_offset]) == 0 ? b : (b + bsz); + const int64_t base_vec = select * dim + d; + + const scalar_t xv = x[idx]; + const scalar_t shiftv = shift[base_vec]; + const scalar_t scalev = scale[base_vec]; + modulated[idx] = xv * (static_cast(1.0f) + scalev) + shiftv; + gate_out[idx] = gate[base_vec]; +} + + +std::vector modulate_forward_cuda(torch::Tensor x, torch::Tensor mod_params) { + const int64_t bsz = x.size(0); + const int64_t seq = x.size(1); + const int64_t dim = x.size(2); + auto modulated = torch::empty_like(x); + + const int64_t total = bsz * seq * dim; + if (total == 0) { + auto gate_out = mod_params.narrow(1, dim * 2, dim).unsqueeze(1).contiguous(); + return {modulated, gate_out}; + } + + auto shift = mod_params.narrow(1, 0, dim).contiguous(); + auto scale = mod_params.narrow(1, dim, dim).contiguous(); + constexpr int threads = 256; + const int blocks = static_cast((total + threads - 1) / threads); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + x.scalar_type(), + "qwen_image_modulate_cuda", + [&] { + modulate_kernel<<>>( + x.data_ptr(), + shift.data_ptr(), + scale.data_ptr(), + modulated.data_ptr(), + bsz, + seq, + dim + ); + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + auto gate_out = mod_params.narrow(1, dim * 2, dim).unsqueeze(1).contiguous(); + return {modulated, gate_out}; +} + + +std::vector modulate_indexed_forward_cuda( + torch::Tensor x, + torch::Tensor mod_params, + torch::Tensor index +) { + const int64_t bsz = x.size(0); + const int64_t seq = x.size(1); + const int64_t dim = x.size(2); + const int64_t index_bsz = index.size(0); + auto modulated = torch::empty_like(x); + auto gate_out = torch::empty_like(x); + + const int64_t total = bsz * seq * dim; + if (total == 0) { + return {modulated, gate_out}; + } + + auto shift = mod_params.narrow(1, 0, dim).contiguous(); + auto scale = mod_params.narrow(1, dim, dim).contiguous(); + auto gate = mod_params.narrow(1, dim * 2, dim).contiguous(); + auto index_2d = index.squeeze(-1); + constexpr int threads = 256; + const int blocks = static_cast((total + threads - 1) / threads); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, + at::kBFloat16, + x.scalar_type(), + "qwen_image_modulate_indexed_cuda", + [&] { + if (index_2d.scalar_type() == at::kLong) { + modulate_indexed_kernel<<>>( + x.data_ptr(), + shift.data_ptr(), + scale.data_ptr(), + gate.data_ptr(), + index_2d.data_ptr(), + modulated.data_ptr(), + gate_out.data_ptr(), + bsz, + seq, + dim, + index_bsz + ); + } else { + modulate_indexed_kernel<<>>( + x.data_ptr(), + shift.data_ptr(), + scale.data_ptr(), + gate.data_ptr(), + index_2d.data_ptr(), + modulated.data_ptr(), + gate_out.data_ptr(), + bsz, + seq, + dim, + index_bsz + ); + } + } + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + return {modulated, gate_out}; +} diff --git a/diffsynth_engine/models/qwen_image/qwen_image_cuda_ext.py b/diffsynth_engine/models/qwen_image/qwen_image_cuda_ext.py new file mode 100644 index 0000000..57229fa --- /dev/null +++ b/diffsynth_engine/models/qwen_image/qwen_image_cuda_ext.py @@ -0,0 +1,101 @@ +import os +import shutil +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.cpp_extension import load + +_EXTENSION_NAME = "qwen_image_cuda_ext_v3" + + +def _sources() -> list[str]: + base_dir = Path(__file__).resolve().parent / "csrc" + return [ + str(base_dir / "qwen_image_rotary_binding.cpp"), + str(base_dir / "qwen_image_rotary_kernel.cu"), + ] + + +def _extension_arch_list() -> str: + # Build for the local visible GPU arch directly to avoid PTX JIT/toolchain + # mismatches, while also bypassing problematic global TORCH_CUDA_ARCH_LIST. + try: + major, minor = torch.cuda.get_device_capability(0) + except Exception: + return "9.0" + return f"{major}.{minor}" + + +def _preferred_host_compilers() -> tuple[Optional[str], Optional[str]]: + # Prefer system GCC/G++ over conda-forge GCC to avoid nvcc host-compiler + # compatibility errors in some conda environments. + cc = "/usr/bin/gcc" if Path("/usr/bin/gcc").exists() else shutil.which("gcc") + cxx = "/usr/bin/g++" if Path("/usr/bin/g++").exists() else shutil.which("g++") + return cc, cxx + + +@lru_cache(maxsize=1) +def _load_extension(): + if not torch.cuda.is_available(): + return None + + old_arch = os.getenv("TORCH_CUDA_ARCH_LIST") + os.environ["TORCH_CUDA_ARCH_LIST"] = _extension_arch_list() + old_cc = os.getenv("CC") + old_cxx = os.getenv("CXX") + cc, cxx = _preferred_host_compilers() + if cc: + os.environ["CC"] = cc + if cxx: + os.environ["CXX"] = cxx + + try: + return load( + name=_EXTENSION_NAME, + sources=_sources(), + extra_cflags=["-O3", "-std=c++17"], + extra_cuda_cflags=["-O3", "--use_fast_math"], + verbose=os.getenv("QWEN_IMAGE_CUDA_EXT_VERBOSE", "0") == "1", + ) + except Exception as err: + if os.getenv("QWEN_IMAGE_CUDA_EXT_WARN", "1") == "1": + print(f"[QwenImage CUDA] rotary extension disabled: {err}") + return None + finally: + if old_arch is None: + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = old_arch + if old_cc is None: + os.environ.pop("CC", None) + else: + os.environ["CC"] = old_cc + if old_cxx is None: + os.environ.pop("CXX", None) + else: + os.environ["CXX"] = old_cxx + + +def rotary_emb_forward(x: torch.Tensor, freqs_cis: torch.Tensor) -> Optional[torch.Tensor]: + ext = _load_extension() + if ext is None: + return None + return ext.rotary_emb_forward(x, freqs_cis) + + +def modulate_forward(x: torch.Tensor, mod_params: torch.Tensor) -> Optional[tuple[torch.Tensor, torch.Tensor]]: + ext = _load_extension() + if ext is None: + return None + return ext.modulate_forward(x, mod_params) + + +def modulate_indexed_forward( + x: torch.Tensor, mod_params: torch.Tensor, index: torch.Tensor +) -> Optional[tuple[torch.Tensor, torch.Tensor]]: + ext = _load_extension() + if ext is None: + return None + return ext.modulate_indexed_forward(x, mod_params, index) diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit.py b/diffsynth_engine/models/qwen_image/qwen_image_dit.py index 9b5933c..bb7e2e7 100644 --- a/diffsynth_engine/models/qwen_image/qwen_image_dit.py +++ b/diffsynth_engine/models/qwen_image/qwen_image_dit.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from dataclasses import dataclass from typing import Any, Dict, List, Tuple, Union, Optional from einops import rearrange from math import prod @@ -161,6 +162,13 @@ def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[ return x_out.type_as(x) +@dataclass +class ImageTokenCache: + static_indices: torch.Tensor + img_k_static: torch.Tensor + img_v_static: torch.Tensor + + class QwenDoubleStreamAttention(nn.Module): def __init__( self, @@ -190,6 +198,72 @@ def __init__( self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) + def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor: + return rearrange(x, "b s (h d) -> b s h d", h=self.num_heads) + + def project_image_qkv(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + img_q = self._reshape_heads(self.to_q(image)) + img_k = self._reshape_heads(self.to_k(image)) + img_v = self._reshape_heads(self.to_v(image)) + return img_q, img_k, img_v + + def project_image_kv(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + img_k = self._reshape_heads(self.to_k(image)) + img_v = self._reshape_heads(self.to_v(image)) + img_k = self.norm_k(img_k) + return img_k, img_v + + def project_text_qkv(self, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + txt_q = self._reshape_heads(self.add_q_proj(text)) + txt_k = self._reshape_heads(self.add_k_proj(text)) + txt_v = self._reshape_heads(self.add_v_proj(text)) + return txt_q, txt_k, txt_v + + def normalize_qk( + self, img_q: torch.Tensor, img_k: torch.Tensor, txt_q: torch.Tensor, txt_k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) + txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) + return img_q, img_k, txt_q, txt_k + + def apply_rotary( + self, + img_q: torch.Tensor, + img_k: torch.Tensor, + txt_q: torch.Tensor, + txt_k: torch.Tensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if rotary_emb is None: + return img_q, img_k, txt_q, txt_k + img_freqs, txt_freqs = rotary_emb + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + return img_q, img_k, txt_q, txt_k + + def apply_image_rotary( + self, + img_q: Optional[torch.Tensor], + img_k: torch.Tensor, + img_freqs: torch.Tensor, + token_indices: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + if token_indices is not None: + img_freqs = img_freqs.index_select(0, token_indices) + if img_q is not None: + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + return img_q, img_k + + def apply_text_rotary( + self, txt_q: torch.Tensor, txt_k: torch.Tensor, txt_freqs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + return txt_q, txt_k + def forward( self, image: torch.FloatTensor, @@ -198,26 +272,10 @@ def forward( attn_mask: Optional[torch.Tensor] = None, attn_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) - txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) - - img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads) - img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads) - img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads) - - txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads) - txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads) - txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads) - - img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) - txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) - - if rotary_emb is not None: - img_freqs, txt_freqs = rotary_emb - img_q = apply_rotary_emb_qwen(img_q, img_freqs) - img_k = apply_rotary_emb_qwen(img_k, img_freqs) - txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) - txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + img_q, img_k, img_v = self.project_image_qkv(image) + txt_q, txt_k, txt_v = self.project_text_qkv(text) + img_q, img_k, txt_q, txt_k = self.normalize_qk(img_q, img_k, txt_q, txt_k) + img_q, img_k, txt_q, txt_k = self.apply_rotary(img_q, img_k, txt_q, txt_k, rotary_emb) joint_q = torch.cat([txt_q, img_q], dim=1) joint_k = torch.cat([txt_k, img_k], dim=1) @@ -236,6 +294,57 @@ def forward( return img_attn_output, txt_attn_output + def forward_with_cached_image_kv( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + cached_img_k: torch.FloatTensor, + cached_img_v: torch.FloatTensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + image_token_indices: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if image_token_indices is None: + image_token_indices = torch.arange(image.shape[1], device=image.device, dtype=torch.long) + + img_q, img_k, img_v = self.project_image_qkv(image) + txt_q, txt_k, txt_v = self.project_text_qkv(text) + img_q, img_k, txt_q, txt_k = self.normalize_qk(img_q, img_k, txt_q, txt_k) + + if rotary_emb is not None: + img_freqs, txt_freqs = rotary_emb + img_q, img_k = self.apply_image_rotary( + img_q=img_q, + img_k=img_k, + img_freqs=img_freqs, + token_indices=image_token_indices, + ) + txt_q, txt_k = self.apply_text_rotary(txt_q, txt_k, txt_freqs) + + joint_q = torch.cat([txt_q, img_q], dim=1) + joint_k = torch.cat([txt_k, img_k, cached_img_k], dim=1) + joint_v = torch.cat([txt_v, img_v, cached_img_v], dim=1) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + attn_mask_dyn = attn_mask + if attn_mask is not None: + txt_len = text.shape[1] + query_indices = torch.cat( + [ + torch.arange(txt_len, device=image.device, dtype=torch.long), + txt_len + image_token_indices, + ], + dim=0, + ) + attn_mask_dyn = attn_mask.index_select(2, query_indices) + joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask_dyn, **attn_kwargs) + joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype) + + txt_attn_output = self.to_add_out(joint_attn_out[:, : text.shape[1], :]) + img_attn_output = self.to_out(joint_attn_out[:, text.shape[1] :, :]) + return img_attn_output, txt_attn_output + class QwenImageTransformerBlock(nn.Module): def __init__( @@ -278,6 +387,83 @@ def __init__( self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype) self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim, device=device, dtype=dtype) self.zero_cond_t = zero_cond_t + self._image_token_cache: Optional[ImageTokenCache] = None + + def clear_image_token_cache(self): + self._image_token_cache = None + + def _resolve_cache_state( + self, modulate_index: Optional[torch.Tensor], use_image_token_cache: bool + ) -> Tuple[Optional[torch.Tensor], bool, bool, Optional[int]]: + static_indices = self._get_static_token_indices(modulate_index) if use_image_token_cache else None + use_static_cache = use_image_token_cache and static_indices is not None + cache_ready = ( + use_static_cache + and self._image_token_cache is not None + and self._image_token_cache.static_indices.shape == static_indices.shape + and torch.equal(self._image_token_cache.static_indices, static_indices) + ) + dynamic_seq_len = None + if use_static_cache: + dynamic_seq_len = int((modulate_index[0].squeeze(-1) == 0).sum().item()) + return static_indices, use_static_cache, cache_ready, dynamic_seq_len + + def _prepare_cached_image(self, image: torch.Tensor, cache_ready: bool, dynamic_seq_len: Optional[int]) -> torch.Tensor: + if cache_ready and dynamic_seq_len is not None and image.shape[1] != dynamic_seq_len: + return image[:, :dynamic_seq_len, :] + return image + + def _align_modulate_index_to_image( + self, modulate_index: Optional[torch.Tensor], image: torch.Tensor + ) -> Optional[torch.Tensor]: + if modulate_index is not None and modulate_index.shape[1] != image.shape[1]: + # Cached path runs on dynamic-only tokens, which are the prefix (all zero-valued index). + return modulate_index[:, : image.shape[1], :] + return modulate_index + + def _apply_img_mlp_residual( + self, image: torch.Tensor, img_mod_mlp: torch.Tensor, img_modulate_index: Optional[torch.Tensor] + ) -> torch.Tensor: + img_normed_2 = self.img_norm2(image) + img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, img_modulate_index) + img_mlp_out = self.img_mlp(img_modulated_2) + return image + img_gate_2 * img_mlp_out + + def _apply_txt_mlp_residual(self, text: torch.Tensor, txt_mod_mlp: torch.Tensor) -> torch.Tensor: + txt_normed_2 = self.txt_norm2(text) + txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + txt_mlp_out = self.txt_mlp(txt_modulated_2) + return text + txt_gate_2 * txt_mlp_out + + def _get_static_token_indices(self, modulate_index: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if modulate_index is None: + return None + index = modulate_index + if index.dim() == 3: + index = index.squeeze(-1) + # `modulate_index` is shared across the batch; use the first sample mask. + static_mask = index[0] == 1 + if not torch.any(static_mask): + return None + return torch.nonzero(static_mask, as_tuple=False).squeeze(-1) + + def _build_static_kv_cache( + self, + img_modulated: torch.Tensor, + static_indices: torch.Tensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_static = img_modulated.index_select(1, static_indices) + img_k_static, img_v_static = self.attn.project_image_kv(img_static) + if rotary_emb is not None: + img_freqs, _ = rotary_emb + _, img_k_static = self.attn.apply_image_rotary( + img_q=None, + img_k=img_k_static, + img_freqs=img_freqs, + token_indices=static_indices, + ) + return img_k_static, img_v_static def _modulate(self, x, mod_params, index=None): shift, scale, gate = mod_params.chunk(3, dim=-1) @@ -310,39 +496,66 @@ def forward( attn_mask: Optional[torch.Tensor] = None, attn_kwargs: Optional[Dict[str, Any]] = None, modulate_index: Optional[List[int]] = None, + use_image_token_cache: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: + static_indices, use_static_cache, cache_ready, dynamic_seq_len = self._resolve_cache_state( + modulate_index, use_image_token_cache + ) + image = self._prepare_cached_image(image, cache_ready, dynamic_seq_len) + img_modulate_index = self._align_modulate_index_to_image(modulate_index, image) + img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each if self.zero_cond_t: temb = torch.chunk(temb, 2, dim=0)[0] txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each img_normed = self.img_norm1(image) - img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, modulate_index) + img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, img_modulate_index) txt_normed = self.txt_norm1(text) txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) - img_attn_out, txt_attn_out = self.attn( - image=img_modulated, - text=txt_modulated, - rotary_emb=rotary_emb, - attn_mask=attn_mask, - attn_kwargs=attn_kwargs, - ) - image = image + img_gate * img_attn_out - text = text + txt_gate * txt_attn_out - - img_normed_2 = self.img_norm2(image) - img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, modulate_index) - - txt_normed_2 = self.txt_norm2(text) - txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + if not cache_ready: + img_attn_out, txt_attn_out = self.attn( + image=img_modulated, + text=txt_modulated, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, + ) - img_mlp_out = self.img_mlp(img_modulated_2) - txt_mlp_out = self.txt_mlp(txt_modulated_2) + image = image + img_gate * img_attn_out + text = text + txt_gate * txt_attn_out + image = self._apply_img_mlp_residual(image, img_mod_mlp, img_modulate_index) + text = self._apply_txt_mlp_residual(text, txt_mod_mlp) + + if use_static_cache: + img_k_static, img_v_static = self._build_static_kv_cache(img_modulated, static_indices, rotary_emb) + self._image_token_cache = ImageTokenCache( + static_indices=static_indices.detach().clone(), + img_k_static=img_k_static.detach().clone(), + img_v_static=img_v_static.detach().clone(), + ) + else: + cached_k_static = self._image_token_cache.img_k_static + cached_v_static = self._image_token_cache.img_v_static + + dynamic_indices = torch.arange(image.shape[1], device=image.device, dtype=torch.long) + img_attn_dyn_out, txt_attn_out = self.attn.forward_with_cached_image_kv( + image=img_modulated, + text=txt_modulated, + cached_img_k=cached_k_static, + cached_img_v=cached_v_static, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, + image_token_indices=dynamic_indices, + ) + text = text + txt_gate * txt_attn_out - image = image + img_gate_2 * img_mlp_out - text = text + txt_gate_2 * txt_mlp_out + image = image + img_gate * img_attn_dyn_out + image = self._apply_img_mlp_residual(image, img_mod_mlp, img_modulate_index) + text = self._apply_txt_mlp_residual(text, txt_mod_mlp) return text, image @@ -355,6 +568,7 @@ def __init__( self, num_layers: int = 60, zero_cond_t: bool = False, + use_image_token_cache: bool = False, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -385,6 +599,7 @@ def __init__( self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype) self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype) self.zero_cond_t = zero_cond_t + self.use_image_token_cache = use_image_token_cache def patchify(self, hidden_states): hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) @@ -537,6 +752,9 @@ def forward( img_freqs, txt_freqs = rotary_emb with sequence_parallel((image, text, img_freqs, txt_freqs, modulate_index), seq_dims=(1, 1, 0, 0, 1)): rotary_emb = (img_freqs, txt_freqs) + # Cache decision is per denoising step, but the KV cache itself lives per block. + # We pass the flag into every block so each block can decide whether to reuse + # the static-image KV part (non-dynamic tokens) or run full attention. for block in self.transformer_blocks: text, image = block( image=image, @@ -546,6 +764,8 @@ def forward( attn_mask=attn_mask, attn_kwargs=attn_kwargs, modulate_index=modulate_index, + # Controls block-level static image KV reuse path. + use_image_token_cache=self.use_image_token_cache, ) if self.zero_cond_t: conditioning = conditioning.chunk(2, dim=0)[0] @@ -577,5 +797,14 @@ def compile_repeated_blocks(self, *args, **kwargs): for block in self.transformer_blocks: block.compile(*args, **kwargs) + def clear_image_token_caches(self): + for block in self.transformer_blocks: + block.clear_image_token_cache() + + def set_image_token_cache_enabled(self, enabled: bool, clear_existing_cache: bool = False): + self.use_image_token_cache = enabled + if clear_existing_cache: + self.clear_image_token_caches() + def get_fsdp_module_cls(self): - return {QwenImageTransformerBlock} + return {QwenImageTransformerBlock} \ No newline at end of file diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit_cuda.py b/diffsynth_engine/models/qwen_image/qwen_image_dit_cuda.py new file mode 100644 index 0000000..e66572d --- /dev/null +++ b/diffsynth_engine/models/qwen_image/qwen_image_dit_cuda.py @@ -0,0 +1,862 @@ +import torch +import torch.nn as nn +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union, Optional +from einops import rearrange +from math import prod + +from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel +from diffsynth_engine.models.basic import attention as attention_ops +from diffsynth_engine.models.basic.timestep import TimestepEmbeddings +from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm +from diffsynth_engine.models.qwen_image.qwen_image_cuda_ext import ( + modulate_forward as modulate_forward_cuda, + modulate_indexed_forward as modulate_indexed_forward_cuda, + rotary_emb_forward as rotary_emb_forward_cuda, +) + +from diffsynth_engine.utils.gguf import gguf_inference +from diffsynth_engine.utils.fp8_linear import fp8_inference +from diffsynth_engine.utils.parallel import ( + cfg_parallel, + cfg_parallel_unshard, + sequence_parallel, + sequence_parallel_unshard, +) + + +class QwenImageDiTStateDictConverter(StateDictConverter): + def __init__(self): + pass + + def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + state_dict_ = {} + dim = 3072 + for name, param in state_dict.items(): + name_ = name + if name.startswith("transformer") and "attn.to_out.0" in name: + name_ = name.replace("attn.to_out.0", "attn.to_out") + if "timestep_embedder.linear_1" in name: + name_ = name.replace("timestep_embedder.linear_1", "timestep_embedder.0") + if "timestep_embedder.linear_2" in name: + name_ = name.replace("timestep_embedder.linear_2", "timestep_embedder.2") + if "norm_out.linear" in name: + param = torch.concat([param[dim:], param[:dim]], dim=0) + state_dict_[name_] = param + return state_dict_ + + def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + state_dict = self._from_diffusers(state_dict) + return state_dict + + +class QwenEmbedRope(nn.Module): + def __init__( + self, + theta: int, + axes_dim: list[int], + scale_rope=False, + device: str = "cuda:0", + ): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + with torch.device("cpu" if device == "meta" else device): + pos_index = torch.arange(10000) + neg_index = torch.arange(10000).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.rope_cache = {} + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward(self, video_fhw, txt_length, device): + """ + Args: + video_fhw (List[Tuple[int, int, int]]): A list of (frame, height, width) tuples for each video/image + txt_length (int): The maximum length of the text sequences + """ + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + +class QwenFeedForward(nn.Module): + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dropout: float = 0.0, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + inner_dim = int(dim * 4) + self.net = nn.ModuleList([]) + self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype)) + self.net.append(nn.Dropout(dropout)) + self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]): + if ( + isinstance(freqs_cis, torch.Tensor) + and x.is_cuda + and freqs_cis.is_cuda + and x.is_contiguous() + and freqs_cis.is_contiguous() + and x.dim() == 4 + and freqs_cis.dim() == 2 + and x.shape[1] == freqs_cis.shape[0] + and x.shape[-1] % 2 == 0 + and freqs_cis.dtype == torch.complex64 + ): + x_out = rotary_emb_forward_cuda(x, freqs_cis) + if x_out is not None: + return x_out + + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) # (b, s, h, d) -> (b, s, h, d/2, 2) + x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3) # (b, s, h, d/2, 2) -> (b, s, h, d) + return x_out.type_as(x) + + +@dataclass +class ImageTokenCache: + static_indices: torch.Tensor + img_k_static: torch.Tensor + img_v_static: torch.Tensor + + +class QwenDoubleStreamAttention(nn.Module): + def __init__( + self, + dim_a, + dim_b, + num_heads, + head_dim, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) + self.to_k = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) + self.to_v = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) + self.norm_q = RMSNorm(head_dim, eps=1e-6, device=device, dtype=dtype) + self.norm_k = RMSNorm(head_dim, eps=1e-6, device=device, dtype=dtype) + + self.add_q_proj = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) + self.add_k_proj = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) + self.add_v_proj = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) + self.norm_added_q = RMSNorm(head_dim, eps=1e-6, device=device, dtype=dtype) + self.norm_added_k = RMSNorm(head_dim, eps=1e-6, device=device, dtype=dtype) + + self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) + self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) + + def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor: + return rearrange(x, "b s (h d) -> b s h d", h=self.num_heads) + + def project_image_qkv(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + img_q = self._reshape_heads(self.to_q(image)) + img_k = self._reshape_heads(self.to_k(image)) + img_v = self._reshape_heads(self.to_v(image)) + return img_q, img_k, img_v + + def project_image_kv(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + img_k = self._reshape_heads(self.to_k(image)) + img_v = self._reshape_heads(self.to_v(image)) + img_k = self.norm_k(img_k) + return img_k, img_v + + def project_text_qkv(self, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + txt_q = self._reshape_heads(self.add_q_proj(text)) + txt_k = self._reshape_heads(self.add_k_proj(text)) + txt_v = self._reshape_heads(self.add_v_proj(text)) + return txt_q, txt_k, txt_v + + def normalize_qk( + self, img_q: torch.Tensor, img_k: torch.Tensor, txt_q: torch.Tensor, txt_k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + img_q, img_k = self.norm_q(img_q), self.norm_k(img_k) + txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k) + return img_q, img_k, txt_q, txt_k + + def apply_rotary( + self, + img_q: torch.Tensor, + img_k: torch.Tensor, + txt_q: torch.Tensor, + txt_k: torch.Tensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if rotary_emb is None: + return img_q, img_k, txt_q, txt_k + img_freqs, txt_freqs = rotary_emb + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + return img_q, img_k, txt_q, txt_k + + def apply_image_rotary( + self, + img_q: Optional[torch.Tensor], + img_k: torch.Tensor, + img_freqs: torch.Tensor, + token_indices: Optional[torch.Tensor] = None, + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + if token_indices is not None: + img_freqs = img_freqs.index_select(0, token_indices) + if img_q is not None: + img_q = apply_rotary_emb_qwen(img_q, img_freqs) + img_k = apply_rotary_emb_qwen(img_k, img_freqs) + return img_q, img_k + + def apply_text_rotary( + self, txt_q: torch.Tensor, txt_k: torch.Tensor, txt_freqs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs) + txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs) + return txt_q, txt_k + + def forward( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + img_q, img_k, img_v = self.project_image_qkv(image) + txt_q, txt_k, txt_v = self.project_text_qkv(text) + img_q, img_k, txt_q, txt_k = self.normalize_qk(img_q, img_k, txt_q, txt_k) + img_q, img_k, txt_q, txt_k = self.apply_rotary(img_q, img_k, txt_q, txt_k, rotary_emb) + + joint_q = torch.cat([txt_q, img_q], dim=1) + joint_k = torch.cat([txt_k, img_k], dim=1) + joint_v = torch.cat([txt_v, img_v], dim=1) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs) + + joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype) + + txt_attn_output = joint_attn_out[:, : text.shape[1], :] + img_attn_output = joint_attn_out[:, text.shape[1] :, :] + + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + def forward_with_cached_image_kv( + self, + image: torch.FloatTensor, + text: torch.FloatTensor, + cached_img_k: torch.FloatTensor, + cached_img_v: torch.FloatTensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + image_token_indices: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if image_token_indices is None: + image_token_indices = torch.arange(image.shape[1], device=image.device, dtype=torch.long) + + img_q, img_k, img_v = self.project_image_qkv(image) + txt_q, txt_k, txt_v = self.project_text_qkv(text) + img_q, img_k, txt_q, txt_k = self.normalize_qk(img_q, img_k, txt_q, txt_k) + + if rotary_emb is not None: + img_freqs, txt_freqs = rotary_emb + img_q, img_k = self.apply_image_rotary( + img_q=img_q, + img_k=img_k, + img_freqs=img_freqs, + token_indices=image_token_indices, + ) + txt_q, txt_k = self.apply_text_rotary(txt_q, txt_k, txt_freqs) + + joint_q = torch.cat([txt_q, img_q], dim=1) + joint_k = torch.cat([txt_k, img_k, cached_img_k], dim=1) + joint_v = torch.cat([txt_v, img_v, cached_img_v], dim=1) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + attn_mask_dyn = attn_mask + if attn_mask is not None: + txt_len = text.shape[1] + query_indices = torch.cat( + [ + torch.arange(txt_len, device=image.device, dtype=torch.long), + txt_len + image_token_indices, + ], + dim=0, + ) + attn_mask_dyn = attn_mask.index_select(2, query_indices) + joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask_dyn, **attn_kwargs) + joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype) + + txt_attn_output = self.to_add_out(joint_attn_out[:, : text.shape[1], :]) + img_attn_output = self.to_out(joint_attn_out[:, text.shape[1] :, :]) + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + zero_cond_t: bool = False, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True, device=device, dtype=dtype), + ) + self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype) + self.attn = QwenDoubleStreamAttention( + dim_a=dim, + dim_b=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + device=device, + dtype=dtype, + ) + self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype) + self.img_mlp = QwenFeedForward(dim=dim, dim_out=dim, device=device, dtype=dtype) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 6 * dim, bias=True, device=device, dtype=dtype), + ) + self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype) + self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype) + self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim, device=device, dtype=dtype) + self.zero_cond_t = zero_cond_t + self._image_token_cache: Optional[ImageTokenCache] = None + + def clear_image_token_cache(self): + self._image_token_cache = None + + def _resolve_cache_state( + self, modulate_index: Optional[torch.Tensor], use_image_token_cache: bool + ) -> Tuple[Optional[torch.Tensor], bool, bool, Optional[int]]: + static_indices = self._get_static_token_indices(modulate_index) if use_image_token_cache else None + use_static_cache = use_image_token_cache and static_indices is not None + cache_ready = ( + use_static_cache + and self._image_token_cache is not None + and self._image_token_cache.static_indices.shape == static_indices.shape + and torch.equal(self._image_token_cache.static_indices, static_indices) + ) + dynamic_seq_len = None + if use_static_cache: + dynamic_seq_len = int((modulate_index[0].squeeze(-1) == 0).sum().item()) + return static_indices, use_static_cache, cache_ready, dynamic_seq_len + + def _prepare_cached_image(self, image: torch.Tensor, cache_ready: bool, dynamic_seq_len: Optional[int]) -> torch.Tensor: + if cache_ready and dynamic_seq_len is not None and image.shape[1] != dynamic_seq_len: + return image[:, :dynamic_seq_len, :] + return image + + def _align_modulate_index_to_image( + self, modulate_index: Optional[torch.Tensor], image: torch.Tensor + ) -> Optional[torch.Tensor]: + if modulate_index is not None and modulate_index.shape[1] != image.shape[1]: + # Cached path runs on dynamic-only tokens, which are the prefix (all zero-valued index). + return modulate_index[:, : image.shape[1], :] + return modulate_index + + def _apply_img_mlp_residual( + self, image: torch.Tensor, img_mod_mlp: torch.Tensor, img_modulate_index: Optional[torch.Tensor] + ) -> torch.Tensor: + img_normed_2 = self.img_norm2(image) + img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, img_modulate_index) + img_mlp_out = self.img_mlp(img_modulated_2) + return image + img_gate_2 * img_mlp_out + + def _apply_txt_mlp_residual(self, text: torch.Tensor, txt_mod_mlp: torch.Tensor) -> torch.Tensor: + txt_normed_2 = self.txt_norm2(text) + txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp) + txt_mlp_out = self.txt_mlp(txt_modulated_2) + return text + txt_gate_2 * txt_mlp_out + + def _get_static_token_indices(self, modulate_index: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if modulate_index is None: + return None + index = modulate_index + if index.dim() == 3: + index = index.squeeze(-1) + # `modulate_index` is shared across the batch; use the first sample mask. + static_mask = index[0] == 1 + if not torch.any(static_mask): + return None + return torch.nonzero(static_mask, as_tuple=False).squeeze(-1) + + def _build_static_kv_cache( + self, + img_modulated: torch.Tensor, + static_indices: torch.Tensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_static = img_modulated.index_select(1, static_indices) + img_k_static, img_v_static = self.attn.project_image_kv(img_static) + if rotary_emb is not None: + img_freqs, _ = rotary_emb + _, img_k_static = self.attn.apply_image_rotary( + img_q=None, + img_k=img_k_static, + img_freqs=img_freqs, + token_indices=static_indices, + ) + return img_k_static, img_v_static + + def _modulate(self, x, mod_params, index=None): + if ( + x.is_cuda + and mod_params.is_cuda + and x.is_contiguous() + and mod_params.is_contiguous() + and x.dim() == 3 + and mod_params.dim() == 2 + and mod_params.shape[1] == x.shape[2] * 3 + and x.dtype in (torch.float16, torch.bfloat16, torch.float32) + and mod_params.dtype == x.dtype + ): + if index is None and mod_params.shape[0] == x.shape[0]: + out = modulate_forward_cuda(x, mod_params) + if out is not None: + return out + + if ( + index is not None + and index.is_cuda + and index.is_contiguous() + and index.dim() == 3 + and index.shape[1] == x.shape[1] + and index.shape[2] == 1 + and mod_params.shape[0] == x.shape[0] * 2 + and index.dtype in (torch.int32, torch.int64) + ): + out = modulate_indexed_forward_cuda(x, mod_params, index) + if out is not None: + return out + + shift, scale, gate = mod_params.chunk(3, dim=-1) + if index is not None: + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + shift_0_exp = shift_0.unsqueeze(1) + shift_1_exp = shift_1.unsqueeze(1) + scale_0_exp = scale_0.unsqueeze(1) + scale_1_exp = scale_1.unsqueeze(1) + gate_0_exp = gate_0.unsqueeze(1) + gate_1_exp = gate_1.unsqueeze(1) + shift_result = torch.where(index == 0, shift_0_exp, shift_1_exp) + scale_result = torch.where(index == 0, scale_0_exp, scale_1_exp) + gate_result = torch.where(index == 0, gate_0_exp, gate_1_exp) + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + return x * (1 + scale_result) + shift_result, gate_result + + def forward( + self, + image: torch.Tensor, + text: torch.Tensor, + temb: torch.Tensor, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + modulate_index: Optional[List[int]] = None, + use_image_token_cache: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + static_indices, use_static_cache, cache_ready, dynamic_seq_len = self._resolve_cache_state( + modulate_index, use_image_token_cache + ) + image = self._prepare_cached_image(image, cache_ready, dynamic_seq_len) + img_modulate_index = self._align_modulate_index_to_image(modulate_index, image) + + img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + if self.zero_cond_t: + temb = torch.chunk(temb, 2, dim=0)[0] + txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each + + img_normed = self.img_norm1(image) + img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, img_modulate_index) + + txt_normed = self.txt_norm1(text) + txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn) + + if not cache_ready: + img_attn_out, txt_attn_out = self.attn( + image=img_modulated, + text=txt_modulated, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, + ) + + image = image + img_gate * img_attn_out + text = text + txt_gate * txt_attn_out + image = self._apply_img_mlp_residual(image, img_mod_mlp, img_modulate_index) + text = self._apply_txt_mlp_residual(text, txt_mod_mlp) + + if use_static_cache: + img_k_static, img_v_static = self._build_static_kv_cache(img_modulated, static_indices, rotary_emb) + self._image_token_cache = ImageTokenCache( + static_indices=static_indices.detach().clone(), + img_k_static=img_k_static.detach().clone(), + img_v_static=img_v_static.detach().clone(), + ) + else: + cached_k_static = self._image_token_cache.img_k_static + cached_v_static = self._image_token_cache.img_v_static + + dynamic_indices = torch.arange(image.shape[1], device=image.device, dtype=torch.long) + img_attn_dyn_out, txt_attn_out = self.attn.forward_with_cached_image_kv( + image=img_modulated, + text=txt_modulated, + cached_img_k=cached_k_static, + cached_img_v=cached_v_static, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, + image_token_indices=dynamic_indices, + ) + text = text + txt_gate * txt_attn_out + + image = image + img_gate * img_attn_dyn_out + image = self._apply_img_mlp_residual(image, img_mod_mlp, img_modulate_index) + text = self._apply_txt_mlp_residual(text, txt_mod_mlp) + + return text, image + + +class QwenImageDiT(PreTrainedModel): + converter = QwenImageDiTStateDictConverter() + _supports_parallelization = True + + def __init__( + self, + num_layers: int = 60, + zero_cond_t: bool = False, + use_image_token_cache: bool = False, + device: str = "cuda:0", + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device) + + self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype) + + self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype) + + self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype) + self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=3072, + num_attention_heads=24, + attention_head_dim=128, + zero_cond_t=zero_cond_t, + device=device, + dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype) + self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype) + self.zero_cond_t = zero_cond_t + self.use_image_token_cache = use_image_token_cache + + def patchify(self, hidden_states): + hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2) + return hidden_states + + def unpatchify(self, hidden_states, height, width): + hidden_states = rearrange( + hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height // 2, W=width // 2 + ) + return hidden_states + + def process_entity_masks( + self, + text: torch.Tensor, + text_seq_lens: torch.LongTensor, + rotary_emb: Tuple[torch.Tensor, torch.Tensor], + video_fhw: List[Tuple[int, int, int]], + entity_text: List[torch.Tensor], + entity_seq_lens: List[torch.LongTensor], + entity_masks: List[torch.Tensor], + device: str, + dtype: torch.dtype, + ): + entity_seq_lens = [seq_lens.max().item() for seq_lens in entity_seq_lens] + text_seq_lens = entity_seq_lens + [text_seq_lens.max().item()] + entity_text = [ + self.txt_in(self.txt_norm(text[:, :seq_len])) for text, seq_len in zip(entity_text, entity_seq_lens) + ] + text = torch.cat(entity_text + [text], dim=1) + + entity_txt_freqs = [self.pos_embed(video_fhw, seq_len, device)[1] for seq_len in entity_seq_lens] + img_freqs, txt_freqs = rotary_emb + txt_freqs = torch.cat(entity_txt_freqs + [txt_freqs], dim=0) + rotary_emb = (img_freqs, txt_freqs) + + global_mask = torch.ones_like(entity_masks[0], device=device, dtype=dtype) + patched_masks = [self.patchify(mask) for mask in entity_masks + [global_mask]] + batch_size, image_seq_len = patched_masks[0].shape[:2] + total_seq_len = sum(text_seq_lens) + image_seq_len + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), device=device, dtype=torch.bool) + + # text-image attention mask + img_start, img_end = sum(text_seq_lens), total_seq_len + cumsum = [0] + for seq_len in text_seq_lens: + cumsum.append(cumsum[-1] + seq_len) + for i, patched_mask in enumerate(patched_masks): + txt_start, txt_end = cumsum[i], cumsum[i + 1] + mask = torch.sum(patched_mask, dim=-1) > 0 + mask = mask.unsqueeze(1).repeat(1, text_seq_lens[i], 1) + # text-to-image attention + attention_mask[:, txt_start:txt_end, img_start:img_end] = mask + # image-to-text attention + attention_mask[:, img_start:img_end, txt_start:txt_end] = mask.transpose(1, 2) + # entity text tokens should not attend to each other + for i in range(len(text_seq_lens)): + for j in range(len(text_seq_lens)): + if i == j: + continue + i_start, i_end = cumsum[i], cumsum[i + 1] + j_start, j_end = cumsum[j], cumsum[j + 1] + attention_mask[:, i_start:i_end, j_start:j_end] = False + + attn_mask = torch.zeros_like(attention_mask, device=device, dtype=dtype) + attn_mask[~attention_mask] = -torch.inf + attn_mask = attn_mask.unsqueeze(1) + return text, rotary_emb, attn_mask + + def forward( + self, + image: torch.Tensor, + edit: torch.Tensor = None, + timestep: torch.LongTensor = None, + text: torch.Tensor = None, + text_seq_lens: torch.LongTensor = None, + context_latents: Optional[torch.Tensor] = None, + entity_text: Optional[List[torch.Tensor]] = None, + entity_seq_lens: Optional[List[torch.LongTensor]] = None, + entity_masks: Optional[List[torch.Tensor]] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, + ): + h, w = image.shape[-2:] + fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False) + use_cfg = image.shape[0] > 1 + with ( + fp8_inference(fp8_linear_enabled), + gguf_inference(), + cfg_parallel( + ( + image, + *(edit if edit is not None else ()), + timestep, + text, + text_seq_lens, + *(entity_text if entity_text is not None else ()), + *(entity_seq_lens if entity_seq_lens is not None else ()), + *(entity_masks if entity_masks is not None else ()), + context_latents, + ), + use_cfg=use_cfg, + ), + ): + if self.zero_cond_t: + timestep = torch.cat([timestep, timestep * 0], dim=0) + modulate_index = None + conditioning = self.time_text_embed(timestep, image.dtype) + video_fhw = [(1, h // 2, w // 2)] # frame, height, width + text_seq_len = text_seq_lens.max().item() + image = self.patchify(image) + image_seq_len = image.shape[1] + if context_latents is not None: + context_latents = context_latents.to(dtype=image.dtype) + context_latents = self.patchify(context_latents) + image = torch.cat([image, context_latents], dim=1) + video_fhw += [(1, h // 2, w // 2)] + if edit is not None: + for img in edit: + img = img.to(dtype=image.dtype) + edit_h, edit_w = img.shape[-2:] + img = self.patchify(img) + image = torch.cat([image, img], dim=1) + video_fhw += [(1, edit_h // 2, edit_w // 2)] + if self.zero_cond_t: + modulate_index = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [video_fhw]], + device=timestep.device, + dtype=torch.int, + ) + modulate_index = modulate_index.unsqueeze(-1) + rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device) + + image = self.img_in(image) + text = self.txt_in(self.txt_norm(text[:, :text_seq_len])) + + attn_mask = None + if entity_text is not None: + text, rotary_emb, attn_mask = self.process_entity_masks( + text, + text_seq_lens, + rotary_emb, + video_fhw, + entity_text, + entity_seq_lens, + entity_masks, + image.device, + image.dtype, + ) + + # warning: Eligen does not work with sequence parallel because long context attention does not support attention masks + img_freqs, txt_freqs = rotary_emb + with sequence_parallel((image, text, img_freqs, txt_freqs, modulate_index), seq_dims=(1, 1, 0, 0, 1)): + rotary_emb = (img_freqs, txt_freqs) + # Cache decision is per denoising step, but the KV cache itself lives per block. + # We pass the flag into every block so each block can decide whether to reuse + # the static-image KV part (non-dynamic tokens) or run full attention. + for block in self.transformer_blocks: + text, image = block( + image=image, + text=text, + temb=conditioning, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, + modulate_index=modulate_index, + # Controls block-level static image KV reuse path. + use_image_token_cache=self.use_image_token_cache, + ) + if self.zero_cond_t: + conditioning = conditioning.chunk(2, dim=0)[0] + image = self.norm_out(image, conditioning) + image = self.proj_out(image) + (image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,)) + image = image[:, :image_seq_len] + image = self.unpatchify(image, h, w) + + (image,) = cfg_parallel_unshard((image,), use_cfg=use_cfg) + return image + + @classmethod + def from_state_dict( + cls, + state_dict: Dict[str, torch.Tensor], + device: str, + dtype: torch.dtype, + num_layers: int = 60, + use_zero_cond_t: bool = False, + ): + model = cls(device="meta", dtype=dtype, num_layers=num_layers, zero_cond_t=use_zero_cond_t) + model = model.requires_grad_(False) + model.load_state_dict(state_dict, assign=True) + model.to(device=device, dtype=dtype, non_blocking=True) + return model + + def compile_repeated_blocks(self, *args, **kwargs): + for block in self.transformer_blocks: + block.compile(*args, **kwargs) + + def clear_image_token_caches(self): + for block in self.transformer_blocks: + block.clear_image_token_cache() + + def set_image_token_cache_enabled(self, enabled: bool, clear_existing_cache: bool = False): + self.use_image_token_cache = enabled + if clear_existing_cache: + self.clear_image_token_caches() + + def get_fsdp_module_cls(self): + return {QwenImageTransformerBlock} diff --git a/diffsynth_engine/pipelines/qwen_image.py b/diffsynth_engine/pipelines/qwen_image.py index 342cb70..03e0548 100644 --- a/diffsynth_engine/pipelines/qwen_image.py +++ b/diffsynth_engine/pipelines/qwen_image.py @@ -665,6 +665,7 @@ def __call__( # single image for edit, list for edit plus(QwenImageEdit2509) input_image: List[Image.Image] | Image.Image | None = None, cfg_scale: float = 4.0, # true cfg + cfg_list: List[float] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -674,6 +675,7 @@ def __call__( # eligen entity_prompts: Optional[List[str]] = None, entity_masks: Optional[List[Image.Image]] = None, + use_image_kv_cache: Optional[bool] = None, ): assert (height is None) == (width is None), "height and width should be set together" is_edit_plus = isinstance(input_image, list) @@ -725,6 +727,20 @@ def __call__( image_latents = None self.load_models_to_device(["encoder"]) + + if cfg_list is not None: + if len(cfg_list) == 0 or max(cfg_list) <= 1: + print( f"[QwenImagePipeline] cfg_list passed but has max values less than 1 or empty, ignoring it." ) + cfg_list = None + else: + if len(cfg_list) != num_inference_steps: + print( f"[QwenImagePipeline] cfg_list length mismatch: got {len(cfg_list)}, expected {num_inference_steps}. Will expand or trim cfg_list automatically for you." ) + if len(cfg_list) < num_inference_steps: + cfg_list = cfg_list + [cfg_list[-1]] * (num_inference_steps - len(cfg_list)) + else: + cfg_list = cfg_list[:num_inference_steps] + cfg_scale = max(cfg_list) + if image_latents is not None: prompt_emb, prompt_emb_mask = self.encode_prompt_with_image( prompt, vae_images, condition_images, 1, 4096, is_edit_plus @@ -756,48 +772,62 @@ def __call__( self.model_lifecycle_finish(["encoder"]) - self.load_models_to_device(["dit"]) - hide_progress = dist.is_initialized() and dist.get_rank() != 0 - for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)): - timestep = timestep.unsqueeze(0).to(dtype=self.dtype) - noise_pred = self.predict_noise_with_cfg( - latents=latents, - image_latents=image_latents, - timestep=timestep, - prompt_emb=prompt_emb, - negative_prompt_emb=negative_prompt_emb, - prompt_emb_mask=prompt_emb_mask, - negative_prompt_emb_mask=negative_prompt_emb_mask, - context_latents=context_latents, - entity_prompt_embs=entity_prompt_embs, - entity_prompt_emb_masks=entity_prompt_emb_masks, - negative_entity_prompt_embs=negative_entity_prompt_embs, - negative_entity_prompt_emb_masks=negative_entity_prompt_emb_masks, - entity_masks=entity_masks, - cfg_scale=cfg_scale, - batch_cfg=self.config.batch_cfg, + previous_cache_enabled = getattr(self.dit, "use_image_token_cache", None) + cache_override = use_image_kv_cache is not None and hasattr(self.dit, "set_image_token_cache_enabled") + if cache_override: + self.dit.set_image_token_cache_enabled(use_image_kv_cache, clear_existing_cache=True) + + try: + self.load_models_to_device(["dit"]) + hide_progress = dist.is_initialized() and dist.get_rank() != 0 + for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)): + timestep = timestep.unsqueeze(0).to(dtype=self.dtype) + if cfg_list is not None: + cfg_scale = cfg_list[i] + + noise_pred = self.predict_noise_with_cfg( + latents=latents, + image_latents=image_latents, + timestep=timestep, + prompt_emb=prompt_emb, + negative_prompt_emb=negative_prompt_emb, + prompt_emb_mask=prompt_emb_mask, + negative_prompt_emb_mask=negative_prompt_emb_mask, + context_latents=context_latents, + entity_prompt_embs=entity_prompt_embs, + entity_prompt_emb_masks=entity_prompt_emb_masks, + negative_entity_prompt_embs=negative_entity_prompt_embs, + negative_entity_prompt_emb_masks=negative_entity_prompt_emb_masks, + entity_masks=entity_masks, + cfg_scale=cfg_scale, + batch_cfg=self.config.batch_cfg, + ) + # Denoise + latents = self.sampler.step(latents, noise_pred, i) + # UI + if progress_callback is not None: + progress_callback(i, len(timesteps), "DENOISING") + self.model_lifecycle_finish(["dit"]) + # Decode image + self.load_models_to_device(["vae"]) + latents = rearrange(latents, "B C H W -> B C 1 H W") + vae_output = rearrange( + self.vae.decode( + latents.to(self.vae.model.encoder.conv1.weight.dtype), + device=self.vae.model.encoder.conv1.weight.device, + tiled=self.vae_tiled, + tile_size=self.vae_tile_size, + tile_stride=self.vae_tile_stride, + )[0], + "C B H W -> B C H W", ) - # Denoise - latents = self.sampler.step(latents, noise_pred, i) - # UI - if progress_callback is not None: - progress_callback(i, len(timesteps), "DENOISING") - self.model_lifecycle_finish(["dit"]) - # Decode image - self.load_models_to_device(["vae"]) - latents = rearrange(latents, "B C H W -> B C 1 H W") - vae_output = rearrange( - self.vae.decode( - latents.to(self.vae.model.encoder.conv1.weight.dtype), - device=self.vae.model.encoder.conv1.weight.device, - tiled=self.vae_tiled, - tile_size=self.vae_tile_size, - tile_stride=self.vae_tile_stride, - )[0], - "C B H W -> B C H W", - ) - image = self.vae_output_to_image(vae_output) - # Offload all models - self.model_lifecycle_finish(["vae"]) - self.load_models_to_device([]) - return image + image = self.vae_output_to_image(vae_output) + # Offload all models + self.model_lifecycle_finish(["vae"]) + self.load_models_to_device([]) + return image + finally: + if cache_override and previous_cache_enabled is not None: + self.dit.set_image_token_cache_enabled(previous_cache_enabled, clear_existing_cache=True) + elif hasattr(self.dit, "clear_image_token_caches"): + self.dit.clear_image_token_caches() diff --git a/examples/benchmark_qwen_modulate_cuda.py b/examples/benchmark_qwen_modulate_cuda.py new file mode 100644 index 0000000..f0c511b --- /dev/null +++ b/examples/benchmark_qwen_modulate_cuda.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +import argparse +import time + +import torch + +from diffsynth_engine.models.qwen_image.qwen_image_cuda_ext import ( + modulate_forward as modulate_forward_cuda, + modulate_indexed_forward as modulate_indexed_forward_cuda, +) + + +def modulate_pytorch(x: torch.Tensor, mod_params: torch.Tensor): + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + +def modulate_indexed_pytorch(x: torch.Tensor, mod_params: torch.Tensor, index: torch.Tensor): + shift, scale, gate = mod_params.chunk(3, dim=-1) + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + shift_result = torch.where(index == 0, shift_0.unsqueeze(1), shift_1.unsqueeze(1)) + scale_result = torch.where(index == 0, scale_0.unsqueeze(1), scale_1.unsqueeze(1)) + gate_result = torch.where(index == 0, gate_0.unsqueeze(1), gate_1.unsqueeze(1)) + return x * (1 + scale_result) + shift_result, gate_result + + +def benchmark(fn, warmup: int, iters: int) -> float: + for _ in range(warmup): + _ = fn() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(iters): + _ = fn() + torch.cuda.synchronize() + return (time.perf_counter() - start) * 1000.0 / iters + + +def parse_args(): + parser = argparse.ArgumentParser(description="Benchmark Qwen modulate CUDA kernel vs PyTorch.") + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--seq", type=int, default=4096) + parser.add_argument("--dim", type=int, default=3072) + parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--check-iters", type=int, default=3) + parser.add_argument("--indexed", action="store_true", help="Benchmark indexed modulation path.") + parser.add_argument("--seed", type=int, default=0) + return parser.parse_args() + + +def dtype_from_str(dtype_name: str) -> torch.dtype: + if dtype_name == "bf16": + return torch.bfloat16 + if dtype_name == "fp16": + return torch.float16 + return torch.float32 + + +def main(): + args = parse_args() + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required to run this benchmark.") + + torch.manual_seed(args.seed) + dtype = dtype_from_str(args.dtype) + device = "cuda" + + x = torch.randn(args.batch, args.seq, args.dim, device=device, dtype=dtype).contiguous() + + if args.indexed: + mod_params = torch.randn(args.batch * 2, args.dim * 3, device=device, dtype=dtype).contiguous() + # Shared index pattern across batch, matching model behavior. + index = torch.randint(0, 2, (1, args.seq, 1), device=device, dtype=torch.int32).contiguous() + ref_fn = lambda: modulate_indexed_pytorch(x, mod_params, index) + cuda_fn = lambda: modulate_indexed_forward_cuda(x, mod_params, index) + mode_name = "indexed" + else: + mod_params = torch.randn(args.batch, args.dim * 3, device=device, dtype=dtype).contiguous() + ref_fn = lambda: modulate_pytorch(x, mod_params) + cuda_fn = lambda: modulate_forward_cuda(x, mod_params) + mode_name = "plain" + + out_cuda = cuda_fn() + if out_cuda is None: + raise RuntimeError("CUDA extension unavailable. Set QWEN_IMAGE_CUDA_EXT_WARN=1 for build errors.") + + atol = 3e-3 if dtype in (torch.float16, torch.bfloat16) else 1e-5 + rtol = 3e-3 if dtype in (torch.float16, torch.bfloat16) else 1e-5 + + max_abs = 0.0 + max_rel = 0.0 + for _ in range(args.check_iters): + y_ref, g_ref = ref_fn() + y_cuda, g_cuda = cuda_fn() + if y_cuda is None or g_cuda is None: + raise RuntimeError("CUDA extension became unavailable during correctness check.") + for a, b in ((y_ref, y_cuda), (g_ref, g_cuda)): + diff = (a - b).abs() + denom = a.abs().clamp_min(1e-6) + max_abs = max(max_abs, diff.max().item()) + max_rel = max(max_rel, (diff / denom).max().item()) + if not torch.allclose(a, b, atol=atol, rtol=rtol): + raise AssertionError( + f"Correctness check failed: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}, " + f"atol={atol}, rtol={rtol}" + ) + + ms_ref = benchmark(ref_fn, args.warmup, args.iters) + ms_cuda = benchmark(cuda_fn, args.warmup, args.iters) + speedup = ms_ref / ms_cuda + + print("=== Qwen Modulate Comparison ===") + print(f"mode: {mode_name}") + print(f"shape: B={args.batch}, S={args.seq}, D={args.dim}, dtype={args.dtype}") + print(f"correctness: PASS (max_abs={max_abs:.6e}, max_rel={max_rel:.6e})") + print(f"pytorch: {ms_ref:.4f} ms/iter") + print(f"cuda : {ms_cuda:.4f} ms/iter") + print(f"speedup: {speedup:.3f}x ({(1.0 - ms_cuda / ms_ref) * 100.0:+.2f}% vs pytorch)") + + +if __name__ == "__main__": + main() diff --git a/examples/benchmark_qwen_rotary_cuda.py b/examples/benchmark_qwen_rotary_cuda.py new file mode 100644 index 0000000..e5b7ad1 --- /dev/null +++ b/examples/benchmark_qwen_rotary_cuda.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +import argparse +import time + +import torch + +from diffsynth_engine.models.qwen_image.qwen_image_cuda_ext import rotary_emb_forward as rotary_emb_forward_cuda + + +def rotary_emb_pytorch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3) + return x_out.type_as(x) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Compare Qwen rotary CUDA kernel against PyTorch implementation.") + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--seq", type=int, default=4096) + parser.add_argument("--heads", type=int, default=24) + parser.add_argument("--dim", type=int, default=128, help="Head dim (must be even).") + parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--check-iters", type=int, default=5) + parser.add_argument("--seed", type=int, default=0) + return parser.parse_args() + + +def dtype_from_str(dtype_name: str) -> torch.dtype: + if dtype_name == "bf16": + return torch.bfloat16 + if dtype_name == "fp16": + return torch.float16 + return torch.float32 + + +def benchmark(fn, x: torch.Tensor, freqs: torch.Tensor, warmup: int, iters: int) -> float: + for _ in range(warmup): + _ = fn(x, freqs) + torch.cuda.synchronize() + + start = time.perf_counter() + for _ in range(iters): + _ = fn(x, freqs) + torch.cuda.synchronize() + return (time.perf_counter() - start) * 1000.0 / iters + + +def main(): + args = parse_args() + if args.dim % 2 != 0: + raise ValueError("--dim must be even") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required to run this benchmark.") + + torch.manual_seed(args.seed) + device = "cuda" + dtype = dtype_from_str(args.dtype) + + x = torch.randn(args.batch, args.seq, args.heads, args.dim, device=device, dtype=dtype).contiguous() + phase = torch.randn(args.seq, args.dim // 2, device=device, dtype=torch.float32) + freqs = torch.polar(torch.ones_like(phase), phase).contiguous() + + y_cuda = rotary_emb_forward_cuda(x, freqs) + if y_cuda is None: + raise RuntimeError( + "CUDA extension failed to load/compile. " + "Set QWEN_IMAGE_CUDA_EXT_WARN=1 to see full build errors." + ) + + atol = 3e-3 if dtype in (torch.float16, torch.bfloat16) else 1e-5 + rtol = 3e-3 if dtype in (torch.float16, torch.bfloat16) else 1e-5 + + max_abs = 0.0 + max_rel = 0.0 + for _ in range(args.check_iters): + x_check = torch.randn_like(x) + y_pt = rotary_emb_pytorch(x_check, freqs) + y_cuda = rotary_emb_forward_cuda(x_check, freqs) + if y_cuda is None: + raise RuntimeError("CUDA extension became unavailable during correctness check.") + diff = (y_cuda - y_pt).abs() + denom = y_pt.abs().clamp_min(1e-6) + max_abs = max(max_abs, diff.max().item()) + max_rel = max(max_rel, (diff / denom).max().item()) + if not torch.allclose(y_cuda, y_pt, atol=atol, rtol=rtol): + raise AssertionError( + f"Correctness check failed: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}, " + f"atol={atol}, rtol={rtol}" + ) + + ms_pt = benchmark(rotary_emb_pytorch, x, freqs, args.warmup, args.iters) + ms_cuda = benchmark(rotary_emb_forward_cuda, x, freqs, args.warmup, args.iters) + speedup = ms_pt / ms_cuda + + print("=== Qwen Rotary Comparison ===") + print(f"shape: B={args.batch}, S={args.seq}, H={args.heads}, D={args.dim}, dtype={args.dtype}") + print(f"correctness: PASS (max_abs={max_abs:.6e}, max_rel={max_rel:.6e})") + print(f"pytorch: {ms_pt:.4f} ms/iter") + print(f"cuda : {ms_cuda:.4f} ms/iter") + print(f"speedup: {speedup:.3f}x ({(1.0 - ms_cuda / ms_pt) * 100.0:+.2f}% vs pytorch)") + + +if __name__ == "__main__": + main() diff --git a/test.py b/test.py new file mode 100644 index 0000000..9451f72 --- /dev/null +++ b/test.py @@ -0,0 +1,125 @@ +import os +import sys +from typing import Any + +# Add project root so data.dataset_creator.image_utils can be imported +sys.path.insert(0, "/home/bingchen/image-model-studio") + +import json, glob +import torch +from PIL import Image +from data.dataset_creator.image_utils import concat_image_list, put_two_id_images_into_one_top_bottom_image, stack_image_list + +import torch +from PIL import Image +from data.dataset_creator.image_utils import put_two_id_images_into_one_top_bottom_image + +#======================= model creating =================================== +from diffsynth_engine import QwenImagePipeline, QwenImagePipelineConfig, fetch_model +from tqdm import tqdm + + + + +""" +gsutil cp gs://uscentral1_ephemeral/multimodal/model_training/bingchen/qwen_image_srpo_pickscore_v1/step-400.safetensors /mnt/localdisk/bingchen/models/srpo_trial1_4h.safetensors + +""" + + +model_path_root = "/data/bingchen/models/Qwen/Qwen-Image-Edit-2511/" +""" +mkdir -p /mnt/localdisk/bingchen/.cache/diffsynth/Qwen/ +cp -r /data/bingchen/models/Qwen/Qwen-Image-Edit-2511 /mnt/localdisk/bingchen/.cache/diffsynth/Qwen/ +""" +model_path_root = "/mnt/localdisk/bingchen/.cache/diffsynth/Qwen/Qwen-Image-Edit-2511/" +def get_ckpt_paths(root, name_pattern): + return sorted(glob.glob(os.path.join(root, name_pattern))) + +config = QwenImagePipelineConfig.basic_config( + model_path=get_ckpt_paths(model_path_root, name_pattern="transformer/*.safetensors"), + encoder_path=get_ckpt_paths(model_path_root, name_pattern="text_encoder/*.safetensors"), + vae_path=get_ckpt_paths(model_path_root, name_pattern="vae/*.safetensors"), + parallelism=1, + use_zero_cond_t=True, + ) +pipeline = QwenImagePipeline.from_pretrained(config) +pipeline.vae_tiled = True +vae_tile_size = 128 +pipeline.vae_tile_size = (vae_tile_size, vae_tile_size) +pipeline.vae_tile_stride = (96, 96) +#pipeline.enable_cpu_offload(offload_mode="cpu_offload", offload_to_disk=False) +pipeline.compile() + +@torch.inference_mode() +def gen(edit_images, prompt, target_resolution, cfg=1, num_steps=8, seed=-1): + if isinstance(cfg, list): + cfg_list = cfg + cfg = 1 + else: + cfg_list = None + inputs = { + "input_image": edit_images, + "prompt": prompt, + #"prompt": f"Put characters in picture 1 and picture 2, into the same position and style as in picture 3. With picture 1's character on the left, picture 2's character on the right. And in the same artistic style of picture 3.", + "cfg_scale": cfg, + "cfg_list": cfg_list, + "negative_prompt": "interwind arms, distortion, warped text, duplicate faces.", + "num_inference_steps": num_steps, + "seed": seed, + "height": target_resolution[1], + "width": target_resolution[0], + } + cur_result = pipeline(**inputs) + return cur_result + +infer_step = 8 +step_lora_dict = { + 4: "/data/bingchen/models/Qwen/Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors", + 8: "/data/bingchen/models/Qwen/Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors", +} + +trained_lora_ckpts = [ + "/mnt/localdisk/bingchen/models/srpo_trial1_4h.safetensors", +] + +pipeline.load_lora(path=trained_lora_ckpts[0], scale=1.0, fused=False) + + +#====================== prepare data ===================== +data_root = "/data/bingchen/dataset/srpo_hpdv2/photo.json" +with open(data_root, "r") as f: + prompt_list = json.load(f) +target_resolution = (768, 1360) +cfg = 1 #[2,2,1,1,1,1,1,1] +#target_resolution = (1440, 2560) +tester_nbr = 8 + +seed = 3333 #9999 + +output_root = "tmp_testing_outputs/srpo_test_v1" +os.makedirs(output_root, exist_ok=True) + + +pipeline.load_lora(path=trained_lora_ckpts[0], scale=1.0, fused=False) +new_col = [] +for pi, cur_prompt in enumerate(prompt_list[:2]): + + res = gen( + edit_images=None, + prompt=cur_prompt, + target_resolution=target_resolution, + cfg=4, + num_steps=40, + seed=seed) + new_col.append(res) + +new_col = stack_image_list(new_col) +new_im_path = output_root+f'/tmp_qwen_{infer_step}_trained.jpg' +new_col.save(new_im_path) + + +""" +org compile: 2.78it/s +cuda 2 kernels: 2.79it/s +""" \ No newline at end of file