Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <torch/extension.h>

torch::Tensor rotary_emb_forward_cuda(torch::Tensor x, torch::Tensor freqs_cis);
std::vector<torch::Tensor> modulate_forward_cuda(torch::Tensor x, torch::Tensor mod_params);
std::vector<torch::Tensor> modulate_indexed_forward_cuda(
torch::Tensor x,
torch::Tensor mod_params,
torch::Tensor index
);

// Validates inputs and dispatches the CUDA rotary-embedding kernel.
//
// x:         [B, S, H, D] real tensor (float32/float16/bfloat16), contiguous, CUDA.
// freqs_cis: [S, D/2] complex64 rotation factors, contiguous, CUDA.
// Returns a tensor of the same shape/dtype as x; each consecutive pair along the
// last dim is rotated by the matching complex factor (see rotary_emb_kernel).
torch::Tensor rotary_emb_forward(torch::Tensor x, torch::Tensor freqs_cis) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(freqs_cis.is_cuda(), "freqs_cis must be a CUDA tensor");
    // The kernel launches on x's device; a freqs_cis pointer from another GPU
    // would be silently dereferenced cross-device.
    TORCH_CHECK(x.device() == freqs_cis.device(), "x and freqs_cis must be on the same device");
    TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
    TORCH_CHECK(freqs_cis.is_contiguous(), "freqs_cis must be contiguous");
    TORCH_CHECK(x.dim() == 4, "x must have shape [B, S, H, D]");
    TORCH_CHECK(freqs_cis.dim() == 2, "freqs_cis must have shape [S, D/2]");
    TORCH_CHECK(x.size(1) == freqs_cis.size(0), "x sequence dimension must match freqs_cis");
    TORCH_CHECK((x.size(3) % 2) == 0, "last dim of x must be even");
    TORCH_CHECK(freqs_cis.scalar_type() == at::kComplexFloat, "freqs_cis must be complex64");
    TORCH_CHECK(
        (x.scalar_type() == at::kFloat) || (x.scalar_type() == at::kHalf) || (x.scalar_type() == at::kBFloat16),
        "x dtype must be float32, float16, or bfloat16"
    );
    TORCH_CHECK(x.size(3) / 2 == freqs_cis.size(1), "freqs_cis second dim must be D/2");
    return rotary_emb_forward_cuda(x, freqs_cis);
}

// Validates inputs and dispatches the CUDA modulation kernel.
//
// x:          [B, S, D] tensor (float32/float16/bfloat16), contiguous, CUDA.
// mod_params: [B, 3D] tensor, same dtype/device as x: [shift | scale | gate].
// Returns {modulated, gate}: modulated = x * (1 + scale) + shift (broadcast
// over S), gate = mod_params[:, 2D:3D] reshaped to [B, 1, D].
std::vector<torch::Tensor> modulate_forward(torch::Tensor x, torch::Tensor mod_params) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(mod_params.is_cuda(), "mod_params must be a CUDA tensor");
    // Both tensors feed the same kernel launch; they must share a device.
    TORCH_CHECK(x.device() == mod_params.device(), "x and mod_params must be on the same device");
    TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
    TORCH_CHECK(mod_params.is_contiguous(), "mod_params must be contiguous");
    TORCH_CHECK(x.dim() == 3, "x must have shape [B, S, D]");
    TORCH_CHECK(mod_params.dim() == 2, "mod_params must have shape [B, 3D]");
    TORCH_CHECK(x.size(0) == mod_params.size(0), "mod_params batch must equal x batch");
    TORCH_CHECK(mod_params.size(1) == x.size(2) * 3, "mod_params last dim must be 3 * D");
    TORCH_CHECK(mod_params.scalar_type() == x.scalar_type(), "mod_params dtype must match x dtype");
    TORCH_CHECK(
        (x.scalar_type() == at::kFloat) || (x.scalar_type() == at::kHalf) || (x.scalar_type() == at::kBFloat16),
        "x dtype must be float32, float16, or bfloat16"
    );
    return modulate_forward_cuda(x, mod_params);
}

// Validates inputs and dispatches the indexed CUDA modulation kernel.
//
// x:          [B, S, D] tensor (float32/float16/bfloat16), contiguous, CUDA.
// mod_params: [2B, 3D] tensor, same dtype/device as x. Row b holds the
//             parameters for index==0, row b+B for index!=0.
// index:      [B or 1, S, 1] int32/int64 selector, contiguous, CUDA
//             (batch dim 1 broadcasts over x's batch).
// Returns {modulated, gate} as per-element selected modulation.
std::vector<torch::Tensor> modulate_indexed_forward(torch::Tensor x, torch::Tensor mod_params, torch::Tensor index) {
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(mod_params.is_cuda(), "mod_params must be a CUDA tensor");
    TORCH_CHECK(index.is_cuda(), "index must be a CUDA tensor");
    // All three tensors feed one kernel launch; they must share a device.
    TORCH_CHECK(x.device() == mod_params.device(), "x and mod_params must be on the same device");
    TORCH_CHECK(x.device() == index.device(), "x and index must be on the same device");
    TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
    TORCH_CHECK(mod_params.is_contiguous(), "mod_params must be contiguous");
    TORCH_CHECK(index.is_contiguous(), "index must be contiguous");
    TORCH_CHECK(x.dim() == 3, "x must have shape [B, S, D]");
    TORCH_CHECK(mod_params.dim() == 2, "mod_params must have shape [2B, 3D]");
    TORCH_CHECK(index.dim() == 3, "index must have shape [B or 1, S, 1]");
    // Without this check an index batch other than 1 or B would be read
    // out of bounds inside the kernel (it assumes index_bsz is 1 or B).
    TORCH_CHECK(
        index.size(0) == 1 || index.size(0) == x.size(0),
        "index batch dim must be 1 or match x batch"
    );
    TORCH_CHECK(index.size(1) == x.size(1), "index sequence dim must match x sequence dim");
    TORCH_CHECK(index.size(2) == 1, "index last dim must be 1");
    TORCH_CHECK(mod_params.size(0) == x.size(0) * 2, "mod_params batch must be 2 * x batch");
    TORCH_CHECK(mod_params.size(1) == x.size(2) * 3, "mod_params last dim must be 3 * D");
    TORCH_CHECK(mod_params.scalar_type() == x.scalar_type(), "mod_params dtype must match x dtype");
    TORCH_CHECK(index.scalar_type() == at::kLong || index.scalar_type() == at::kInt, "index must be int32 or int64");
    TORCH_CHECK(
        (x.scalar_type() == at::kFloat) || (x.scalar_type() == at::kHalf) || (x.scalar_type() == at::kBFloat16),
        "x dtype must be float32, float16, or bfloat16"
    );
    return modulate_indexed_forward_cuda(x, mod_params, index);
}

// Python bindings: exposes the three validated host wrappers above as a
// torch extension module (module name supplied by the build via
// TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rotary_emb_forward", &rotary_emb_forward, "Qwen rotary embedding forward (CUDA)");
m.def("modulate_forward", &modulate_forward, "Qwen modulation forward (CUDA)");
m.def("modulate_indexed_forward", &modulate_indexed_forward, "Qwen indexed modulation forward (CUDA)");
}
250 changes: 250 additions & 0 deletions diffsynth_engine/models/qwen_image/csrc/qwen_image_rotary_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

// Applies rotary position embedding to x.
//
// Layout: x/out are [bsz, seq, heads, dim] contiguous; freqs is [seq, dim/2]
// complex64. One thread handles one (even, odd) pair of the last dimension,
// treating it as a complex number and multiplying by the matching factor.
// Math is performed in float regardless of scalar_t and cast back on store.
template <typename scalar_t>
__global__ void rotary_emb_kernel(
    const scalar_t* __restrict__ x,
    const c10::complex<float>* __restrict__ freqs,
    scalar_t* __restrict__ out,
    int64_t bsz,
    int64_t seq,
    int64_t heads,
    int64_t dim
) {
    const int64_t half_dim = dim / 2;
    const int64_t n_pairs = bsz * seq * heads * half_dim;
    const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (gid >= n_pairs) {
        return;  // grid tail guard
    }

    // Unflatten gid -> (batch, pos, head, pair).
    const int64_t pair = gid % half_dim;
    const int64_t rem1 = gid / half_dim;
    const int64_t head = rem1 % heads;
    const int64_t rem2 = rem1 / heads;
    const int64_t pos = rem2 % seq;
    const int64_t batch = rem2 / seq;

    const int64_t off = (((batch * seq + pos) * heads + head) * dim) + 2 * pair;
    const c10::complex<float> w = freqs[pos * half_dim + pair];
    const float re = static_cast<float>(x[off]);
    const float im = static_cast<float>(x[off + 1]);

    // Complex multiply (re + i*im) * (w.real + i*w.imag).
    out[off] = static_cast<scalar_t>(re * w.real() - im * w.imag());
    out[off + 1] = static_cast<scalar_t>(re * w.imag() + im * w.real());
}

// Launches rotary_emb_kernel with one thread per (even, odd) pair of x.
// Shape/dtype/contiguity validation is done by the host wrapper.
torch::Tensor rotary_emb_forward_cuda(torch::Tensor x, torch::Tensor freqs_cis) {
    // Pin the launch (and stream lookup below) to x's device; the thread's
    // current device may differ in multi-GPU processes.
    const c10::cuda::CUDAGuard device_guard(x.device());
    auto out = torch::empty_like(x);
    const int64_t bsz = x.size(0);
    const int64_t seq = x.size(1);
    const int64_t heads = x.size(2);
    const int64_t dim = x.size(3);
    const int64_t total_pairs = bsz * seq * heads * (dim >> 1);

    if (total_pairs == 0) {
        return out;  // empty tensor: nothing to launch
    }

    constexpr int threads = 256;
    const int blocks = static_cast<int>((total_pairs + threads - 1) / threads);
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kHalf,
        at::kBFloat16,
        x.scalar_type(),
        "qwen_image_rotary_emb_cuda",
        [&] {
            rotary_emb_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
                x.data_ptr<scalar_t>(),
                freqs_cis.data_ptr<c10::complex<float>>(),
                out.data_ptr<scalar_t>(),
                bsz,
                seq,
                heads,
                dim
            );
        }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return out;
}


// AdaLN-style modulation: modulated[b, s, d] = x[b, s, d] * (1 + scale[b, d]) + shift[b, d].
//
// x/modulated are [bsz, seq, dim] contiguous; shift/scale are [bsz, dim]
// contiguous and broadcast over the sequence dimension. One thread per
// element. Arithmetic is done in float even for half/bfloat16 inputs
// (matching the precision policy of rotary_emb_kernel), then cast on store.
template <typename scalar_t>
__global__ void modulate_kernel(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shift,
    const scalar_t* __restrict__ scale,
    scalar_t* __restrict__ modulated,
    int64_t bsz,
    int64_t seq,
    int64_t dim
) {
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t total = bsz * seq * dim;
    if (idx >= total) return;

    const int64_t d = idx % dim;
    const int64_t t = idx / dim;
    const int64_t b = t / seq;  // sequence index is irrelevant: params broadcast over seq
    const int64_t base_vec = b * dim + d;

    // Compute in float to avoid intermediate rounding for half/bfloat16.
    const float xv = static_cast<float>(x[idx]);
    const float shiftv = static_cast<float>(shift[base_vec]);
    const float scalev = static_cast<float>(scale[base_vec]);
    modulated[idx] = static_cast<scalar_t>(xv * (1.0f + scalev) + shiftv);
}


// Indexed modulation: per (batch, seq) position, index selects which of two
// parameter rows to use — row b when index == 0, row b + bsz otherwise.
//
// x/modulated/gate_out are [bsz, seq, dim] contiguous; shift/scale/gate are
// [2*bsz, dim] contiguous; index is [index_bsz, seq] with index_bsz == 1
// (broadcast over batch) or index_bsz == bsz. One thread per element.
// Arithmetic is done in float even for half/bfloat16 inputs (matching
// rotary_emb_kernel); the gate value is copied through unchanged.
template <typename scalar_t, typename index_t>
__global__ void modulate_indexed_kernel(
    const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ shift,
    const scalar_t* __restrict__ scale,
    const scalar_t* __restrict__ gate,
    const index_t* __restrict__ index,
    scalar_t* __restrict__ modulated,
    scalar_t* __restrict__ gate_out,
    int64_t bsz,
    int64_t seq,
    int64_t dim,
    int64_t index_bsz
) {
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const int64_t total = bsz * seq * dim;
    if (idx >= total) return;

    const int64_t d = idx % dim;
    const int64_t t = idx / dim;
    const int64_t s = t % seq;
    const int64_t b = t / seq;

    // index_bsz == 1 broadcasts a single index row across the batch.
    const int64_t ib = (index_bsz == 1) ? 0 : b;
    const int64_t index_offset = (ib * seq + s);
    const int64_t select = static_cast<int64_t>(index[index_offset]) == 0 ? b : (b + bsz);
    const int64_t base_vec = select * dim + d;

    // Compute in float to avoid intermediate rounding for half/bfloat16.
    const float xv = static_cast<float>(x[idx]);
    const float shiftv = static_cast<float>(shift[base_vec]);
    const float scalev = static_cast<float>(scale[base_vec]);
    modulated[idx] = static_cast<scalar_t>(xv * (1.0f + scalev) + shiftv);
    gate_out[idx] = gate[base_vec];  // pass-through, no math: copy as-is
}


// Launches modulate_kernel with one thread per element of x and returns
// {modulated, gate} where gate is the last third of mod_params as [B, 1, D].
// Shape/dtype/contiguity validation is done by the host wrapper.
std::vector<torch::Tensor> modulate_forward_cuda(torch::Tensor x, torch::Tensor mod_params) {
    // Pin the launch (and stream lookup below) to x's device; the thread's
    // current device may differ in multi-GPU processes.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const int64_t bsz = x.size(0);
    const int64_t seq = x.size(1);
    const int64_t dim = x.size(2);
    auto modulated = torch::empty_like(x);

    const int64_t total = bsz * seq * dim;
    if (total == 0) {
        // Still return a correctly-shaped gate for empty inputs.
        auto gate_out = mod_params.narrow(1, dim * 2, dim).unsqueeze(1).contiguous();
        return {modulated, gate_out};
    }

    // mod_params packs [shift | scale | gate] along dim 1.
    auto shift = mod_params.narrow(1, 0, dim).contiguous();
    auto scale = mod_params.narrow(1, dim, dim).contiguous();
    constexpr int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kHalf,
        at::kBFloat16,
        x.scalar_type(),
        "qwen_image_modulate_cuda",
        [&] {
            modulate_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
                x.data_ptr<scalar_t>(),
                shift.data_ptr<scalar_t>(),
                scale.data_ptr<scalar_t>(),
                modulated.data_ptr<scalar_t>(),
                bsz,
                seq,
                dim
            );
        }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    auto gate_out = mod_params.narrow(1, dim * 2, dim).unsqueeze(1).contiguous();
    return {modulated, gate_out};
}


// Launches modulate_indexed_kernel with one thread per element of x and
// returns {modulated, gate_out}, both shaped like x. The index dtype
// (int32/int64) selects the kernel instantiation.
// Shape/dtype/contiguity validation is done by the host wrapper.
std::vector<torch::Tensor> modulate_indexed_forward_cuda(
    torch::Tensor x,
    torch::Tensor mod_params,
    torch::Tensor index
) {
    // Pin the launch (and stream lookup below) to x's device; the thread's
    // current device may differ in multi-GPU processes.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const int64_t bsz = x.size(0);
    const int64_t seq = x.size(1);
    const int64_t dim = x.size(2);
    const int64_t index_bsz = index.size(0);
    auto modulated = torch::empty_like(x);
    auto gate_out = torch::empty_like(x);

    const int64_t total = bsz * seq * dim;
    if (total == 0) {
        return {modulated, gate_out};
    }

    // mod_params packs [shift | scale | gate] along dim 1 for 2*bsz rows.
    auto shift = mod_params.narrow(1, 0, dim).contiguous();
    auto scale = mod_params.narrow(1, dim, dim).contiguous();
    auto gate = mod_params.narrow(1, dim * 2, dim).contiguous();
    // Drop the trailing size-1 dim so the kernel indexes a flat [B or 1, S] array.
    auto index_2d = index.squeeze(-1);
    constexpr int threads = 256;
    const int blocks = static_cast<int>((total + threads - 1) / threads);
    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::kHalf,
        at::kBFloat16,
        x.scalar_type(),
        "qwen_image_modulate_indexed_cuda",
        [&] {
            if (index_2d.scalar_type() == at::kLong) {
                modulate_indexed_kernel<scalar_t, int64_t><<<blocks, threads, 0, stream>>>(
                    x.data_ptr<scalar_t>(),
                    shift.data_ptr<scalar_t>(),
                    scale.data_ptr<scalar_t>(),
                    gate.data_ptr<scalar_t>(),
                    index_2d.data_ptr<int64_t>(),
                    modulated.data_ptr<scalar_t>(),
                    gate_out.data_ptr<scalar_t>(),
                    bsz,
                    seq,
                    dim,
                    index_bsz
                );
            } else {
                modulate_indexed_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
                    x.data_ptr<scalar_t>(),
                    shift.data_ptr<scalar_t>(),
                    scale.data_ptr<scalar_t>(),
                    gate.data_ptr<scalar_t>(),
                    index_2d.data_ptr<int>(),
                    modulated.data_ptr<scalar_t>(),
                    gate_out.data_ptr<scalar_t>(),
                    bsz,
                    seq,
                    dim,
                    index_bsz
                );
            }
        }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return {modulated, gate_out};
}
Loading