Skip to content

Commit cb7cbac

Browse files
committed
added glm4_moe_lite support
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent e82a0ad commit cb7cbac

5 files changed

Lines changed: 141 additions & 5 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Defuser currently supports the following `transformers==5.3.0` `model_type` valu
5353
| Model type | Defused op performed |
5454
| --- | --- |
5555
| `glm4_moe` | Replaces `Glm4MoeMoE` with a defused per-expert linear MoE block. |
56+
| `glm4_moe_lite` | Replaces `Glm4MoeLiteMoE` with a defused per-expert linear MoE block. |
5657
| `glm4v` | Replaces the fused text MLP with split `gate_proj`, `up_proj`, and `down_proj` layers. Also splits fused checkpoint `mlp.gate_up_proj.weight` into `mlp.gate_proj.weight` + `mlp.up_proj.weight`. |
5758
| `mixtral` | Replaces `MixtralSparseMoeBlock` with `LinearMixtralSparseMoeBlock`. Also remaps legacy Mixtral checkpoint keys and splits fused expert `gate_up_proj` tensors into per-expert `gate_proj` and `up_proj`, plus per-expert `down_proj`. |
5859
| `qwen2_moe` | Replaces `Qwen2MoeSparseMoeBlock` with a defused per-expert linear MoE block. |

defuser/model_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ class PATCH(str, Enum):
135135
},
136136
"glm4_moe_lite": {
137137
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
138+
PATCH.REPLACE_MODULE: [
139+
(
140+
"transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteMoE",
141+
"defuser.modeling.unfused_moe.glm4_moe_lite.LinearGlm4MoeLiteMoE",
142+
)
143+
],
138144
},
139145
"glm4v": {
140146
"min_transformers_version": MIN_SUPPORTED_TRANSFORMERS_VERSION,
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# SPDX-FileCopyrightText: 2026 ModelCloud.ai
2+
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
3+
# SPDX-License-Identifier: Apache-2.0
4+
# Contact: qubitium@modelcloud.ai, x.com/qubitium
5+
6+
# Adapted from intel/auto-round
7+
# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/unfused_moe/glm_moe_light.py
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
class LinearGlm4MoeLiteMoE(nn.Module):
    """Defused GLM4-MoE-Lite mixture-of-experts block with shared experts.

    Drop-in replacement for ``Glm4MoeLiteMoE`` that instantiates one
    ``Glm4MoeLiteMLP`` per routed expert (plus a widened shared-expert MLP),
    so every expert exposes plain linear projections instead of fused
    expert weight tensors.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_local_experts
        # Imported lazily so this module can be imported even when the
        # installed transformers version lacks glm4_moe_lite support.
        from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
            Glm4MoeLiteMLP,
            Glm4MoeLiteTopkRouter,
        )

        self.experts = nn.ModuleList(
            [Glm4MoeLiteMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

        self.gate = Glm4MoeLiteTopkRouter(config)
        self.shared_experts = Glm4MoeLiteMLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.top_k = config.num_experts_per_tok

    def experts_forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Dispatch tokens to their selected experts and combine the outputs.

        Args:
            hidden_states: Flattened tokens, shape ``(num_tokens, hidden_size)``.
            top_k_index: Selected expert ids per token, shape ``(num_tokens, top_k)``.
            top_k_weights: Routing weights aligned with ``top_k_index``.

        Returns:
            Tensor of shape ``(num_tokens, hidden_size)`` holding the
            routing-weighted sum of expert outputs per token.
        """
        final_hidden_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            # Only loop over experts that received at least one token.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx == self.num_experts:
                # Defensive guard for an out-of-range expert id; with
                # num_classes == num_experts above it cannot trigger.
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            expert_layer = self.experts[expert_idx]
            current_hidden_states = expert_layer(current_state)
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states

    def route_tokens_to_experts(self, router_logits):
        """Select ``top_k`` experts per token via grouped (node-limited) routing.

        Sigmoid scores (plus the router's correction bias) are grouped into
        ``n_group`` buckets; each bucket is scored by the sum of its top-2
        entries, only the best ``topk_group`` buckets stay eligible, and the
        final ``top_k`` experts are chosen from the surviving scores.

        Args:
            router_logits: Raw gate logits, shape ``(num_tokens, n_routed_experts)``.

        Returns:
            Tuple ``(topk_indices, topk_weights)`` of shape
            ``(num_tokens, top_k)`` each; weights are optionally normalized
            and scaled by ``routed_scaling_factor``.
        """
        router_logits = router_logits.sigmoid()
        # Expert choice uses bias-corrected scores; weights use raw sigmoid scores.
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        group_scores = (
            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # Expand the per-group mask back to per-expert resolution.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states):
        """Apply routed experts plus the shared experts to ``hidden_states``."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.experts_forward(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        # Shared experts always run on the pre-routing activations.
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

tests/test_convert_model.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
from torch import nn
1111
from transformers.core_model_loading import WeightConverter
1212
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM, Glm4MoeMoE
13+
from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import Glm4MoeLiteConfig
14+
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import (
15+
Glm4MoeLiteForCausalLM,
16+
Glm4MoeLiteMoE,
17+
)
1318
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig
1419
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
1520
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
@@ -48,6 +53,7 @@
4853
from defuser.model_registry import MODEL_CONFIG, PATCH
4954
from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model
5055
from defuser.modeling.unfused_moe.glm4_moe import LinearGlm4MoeMoE
56+
from defuser.modeling.unfused_moe.glm4_moe_lite import LinearGlm4MoeLiteMoE
5157
from defuser.modeling.unfused_moe.mixtral import LinearMixtralSparseMoeBlock
5258
from defuser.modeling.unfused_moe.qwen2_moe import LinearQwen2MoeSparseMoeBlock
5359
from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock
@@ -61,9 +67,9 @@
6167

6268

6369

64-
def _tiny_moe_config(config_cls):
70+
def _tiny_moe_config(config_cls, num_hidden_layers: int=1):
6571
return config_cls(
66-
num_hidden_layers=1,
72+
num_hidden_layers=num_hidden_layers,
6773
hidden_size=64,
6874
intermediate_size=128,
6975
moe_intermediate_size=32,
@@ -936,6 +942,29 @@ def test_defused_models_preserve_output_router_logits_capture():
936942
assert len(outputs.router_logits) == 1
937943
assert outputs.router_logits[0].shape == (3, model.config.num_experts)
938944

945+
def test_glm4_moe_lite():
    """glm4_moe_lite conversion replaces expert blocks with defused modules."""
    model_type = "glm4_moe_lite"
    replace_fused_blocks(model_type)

    config = _tiny_moe_config(Glm4MoeLiteConfig, num_hidden_layers=2)
    model = Glm4MoeLiteForCausalLM(config)
    assert model.config.model_type == model_type

    # In GLM4-MoE-Lite, the `mlp.experts` module is present only starting from the second layer.
    assert not convert_model(model, max_layers=2)

    _assert_unfused_expert_module(model.model.layers[1].mlp.experts)
957+
958+
959+
def test_glm4_moe_lite_defused_forward_matches_fused_math():
    """Defused GLM4-MoE-Lite block reproduces the fused block's forward math."""
    config = _tiny_moe_config(Glm4MoeLiteConfig)
    sample = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)

    fused = Glm4MoeLiteMoE(config)
    defused = LinearGlm4MoeLiteMoE(config)
    _assert_sparse_moe_defused_matches_fused_math(fused, defused, sample)

940969
def test_glm4v_checkpoint_mapping_splits_gate_up_proj():
941970
from defuser.defuser import get_checkpoint_conversion_mapping

tests/test_meta_model_defusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,13 @@ def _validate_defused_module(case: dict, module) -> None:
382382
},
383383
{
384384
"model_type": "glm4_moe_lite",
385-
"mode": "convert",
385+
"mode": "replace",
386386
"model_module": "transformers.models.glm4_moe_lite.modeling_glm4_moe_lite",
387387
"model_class": "Glm4MoeLiteForCausalLM",
388388
"config_module": "transformers.models.glm4_moe_lite.configuration_glm4_moe_lite",
389389
"config_class": "Glm4MoeLiteConfig",
390-
"target_class_paths": ("transformers.models.glm4_moe_lite.modeling_glm4_moe_lite.Glm4MoeLiteNaiveMoe",),
391-
"validator": "experts",
390+
"target_class_paths": ("defuser.modeling.unfused_moe.glm4_moe_lite.LinearGlm4MoeLiteMoE",),
391+
"validator": "sparse_block",
392392
},
393393
{
394394
"model_type": "glm4v",

0 commit comments

Comments
 (0)