2626################################################################################
2727
2828EmbeddingMergeMethod = arch_typing .EmbeddingMergeMethod
29- ConditioningMechanism = arch_typing . ConditioningMechanism
29+
3030
3131################################################################################
3232# MARK: Tests
@@ -39,13 +39,13 @@ class EncodeConditioningTest(parameterized.TestCase):
3939 (
4040 'test1' ,
4141 EmbeddingMergeMethod .SUM ,
42- ConditioningMechanism . ADAPTIVE_NORM ,
42+ 'adaptive_norm' ,
4343 True ,
4444 ),
4545 (
4646 'test2' ,
4747 EmbeddingMergeMethod .CONCAT ,
48- ConditioningMechanism . CROSS_ATTENTION ,
48+ 'cross_attention' ,
4949 False ,
5050 ),
5151 )
@@ -110,13 +110,13 @@ def test_basic(
110110 (
111111 'test1' ,
112112 EmbeddingMergeMethod .SUM ,
113- ConditioningMechanism . ADAPTIVE_NORM ,
113+ 'adaptive_norm' ,
114114 True ,
115115 ),
116116 (
117117 'test2' ,
118118 EmbeddingMergeMethod .CONCAT ,
119- ConditioningMechanism . CROSS_ATTENTION ,
119+ 'cross_attention' ,
120120 False ,
121121 ),
122122 )
@@ -182,13 +182,13 @@ def test_mlp_embedder(
182182 (
183183 'test1' ,
184184 EmbeddingMergeMethod .SUM ,
185- ConditioningMechanism . ADAPTIVE_NORM ,
185+ 'adaptive_norm' ,
186186 True ,
187187 ),
188188 (
189189 'test2' ,
190190 EmbeddingMergeMethod .CONCAT ,
191- ConditioningMechanism . CROSS_ATTENTION ,
191+ 'cross_attention' ,
192192 False ,
193193 ),
194194 )
@@ -257,13 +257,13 @@ def test_mlp_embedder_process_multiple_keys(
257257 (
258258 'test1' ,
259259 EmbeddingMergeMethod .SUM ,
260- ConditioningMechanism . ADAPTIVE_NORM ,
260+ 'adaptive_norm' ,
261261 True ,
262262 ),
263263 (
264264 'test2' ,
265265 EmbeddingMergeMethod .CONCAT ,
266- ConditioningMechanism . CROSS_ATTENTION ,
266+ 'cross_attention' ,
267267 False ,
268268 ),
269269 )
@@ -330,8 +330,8 @@ def test_field_selector_embedder(self):
330330 )
331331 conditioning_encoders = {'image' : image_selector }
332332 conditioning_rules = {
333- 'time' : ConditioningMechanism . ADAPTIVE_NORM ,
334- 'image' : ConditioningMechanism . CROSS_ATTENTION ,
333+ 'time' : 'adaptive_norm' ,
334+ 'image' : 'cross_attention' ,
335335 }
336336 embedding_merging_method = EmbeddingMergeMethod .CONCAT
337337
@@ -353,18 +353,18 @@ def test_field_selector_embedder(self):
353353 {'params' : params }, t , c , is_training = False , rngs = {'dropout' : rng }
354354 )
355355
356- self .assertIn (ConditioningMechanism . CROSS_ATTENTION , output )
356+ self .assertIn ('cross_attention' , output )
357357 self .assertEqual (
358- output [ConditioningMechanism . CROSS_ATTENTION ].shape ,
358+ output ['cross_attention' ].shape ,
359359 (batch_size ,) + image_shape ,
360360 )
361361 self .assertTrue (
362- jnp .all (output [ConditioningMechanism . CROSS_ATTENTION ] == c ['image' ])
362+ jnp .all (output ['cross_attention' ] == c ['image' ])
363363 )
364364
365- self .assertIn (ConditioningMechanism . ADAPTIVE_NORM , output )
365+ self .assertIn ('adaptive_norm' , output )
366366 self .assertEqual (
367- output [ConditioningMechanism . ADAPTIVE_NORM ].shape ,
367+ output ['adaptive_norm' ].shape ,
368368 (batch_size , num_features ),
369369 )
370370
@@ -384,8 +384,8 @@ def test_field_selector_embedder_fails_on_missing_key(self):
384384 )
385385 conditioning_encoders = {'image' : image_selector }
386386 conditioning_rules = {
387- 'time' : ConditioningMechanism . ADAPTIVE_NORM ,
388- 'image' : ConditioningMechanism . CROSS_ATTENTION ,
387+ 'time' : 'adaptive_norm' ,
388+ 'image' : 'cross_attention' ,
389389 }
390390 embedding_merging_method = EmbeddingMergeMethod .CONCAT
391391
@@ -411,7 +411,7 @@ def test_field_selector_embedder_fails_on_missing_key(self):
411411 (
412412 'test1' ,
413413 EmbeddingMergeMethod .CONCAT ,
414- ConditioningMechanism . CROSS_ATTENTION ,
414+ 'cross_attention' ,
415415 8 ,
416416 16 ,
417417 False ,
@@ -475,7 +475,7 @@ def test_different_num_features(
475475 (
476476 'test1' ,
477477 EmbeddingMergeMethod .CONCAT ,
478- ConditioningMechanism . CROSS_ATTENTION ,
478+ 'cross_attention' ,
479479 8 ,
480480 9 ,
481481 10 ,
@@ -484,7 +484,7 @@ def test_different_num_features(
484484 (
485485 'test2' ,
486486 EmbeddingMergeMethod .CONCAT ,
487- ConditioningMechanism . CROSS_ATTENTION ,
487+ 'cross_attention' ,
488488 8 ,
489489 9 ,
490490 10 ,
@@ -567,7 +567,7 @@ def test_multilabel(
567567 (
568568 'test1' ,
569569 EmbeddingMergeMethod .CONCAT ,
570- ConditioningMechanism . CROSS_ATTENTION ,
570+ 'cross_attention' ,
571571 8 ,
572572 9 ,
573573 10 ,
@@ -576,7 +576,7 @@ def test_multilabel(
576576 (
577577 'test2' ,
578578 EmbeddingMergeMethod .CONCAT ,
579- ConditioningMechanism . CROSS_ATTENTION ,
579+ 'cross_attention' ,
580580 8 ,
581581 9 ,
582582 10 ,
@@ -668,8 +668,8 @@ def test_dropout(self):
668668 conditioning_embedders = conditioning_encoders ,
669669 embedding_merging_method = EmbeddingMergeMethod .SUM ,
670670 conditioning_rules = {
671- 'time' : ConditioningMechanism . ADAPTIVE_NORM ,
672- 'label' : ConditioningMechanism . ADAPTIVE_NORM ,
671+ 'time' : 'adaptive_norm' ,
672+ 'label' : 'adaptive_norm' ,
673673 },
674674 conditioning_dropout_rate = 1.0 , # Drop all conditioning
675675 )
0 commit comments