|
38 | 38 | def _get_invalid_num_features_params(): |
39 | 39 | """Generates parameters for testing invalid num_features.""" |
40 | 40 | params = [] |
41 | | - modes = ["sinusoidal_embedding", "random_fourier_embedding"] |
| 41 | + modes = [ |
| 42 | + "sinusoidal_embedding", |
| 43 | + "random_fourier_embedding", |
| 44 | + "additive_embedding", |
| 45 | + ] |
42 | 46 | feature_values = [ |
43 | 47 | ("default", INVALID_INT), |
44 | 48 | ("zero", 0), |
@@ -102,6 +106,10 @@ def test_sequence_embedding_raises_error_on_invalid_num_features( |
102 | 106 | module = sequence_embedders.RandomFourierSequenceEmbedding( |
103 | 107 | num_features=num_features |
104 | 108 | ) |
| 109 | + elif mode == "additive_embedding": |
| 110 | + module = sequence_embedders.AdditiveSequenceEmbedding( |
| 111 | + num_features=num_features |
| 112 | + ) |
105 | 113 | else: |
106 | 114 | self.fail(f"Unknown mode: {mode}") |
107 | 115 | inputs = jnp.arange(self.batch_size) |
@@ -244,6 +252,30 @@ def test_rope_embedding_has_no_params(self): |
244 | 252 | variables = module.init(self.rng, x_rope) |
245 | 253 | self.assertEmpty(variables) |
246 | 254 |
|
| 255 | + # MARK: AdditiveSequenceEmbedding tests |
| 256 | + |
| 257 | + def test_additive_embedding_output_shape(self): |
| 258 | + """Tests the output shape of AdditiveSequenceEmbedding.""" |
| 259 | + module = sequence_embedders.AdditiveSequenceEmbedding(num_features=self.dim) |
| 260 | + variables = module.init({"params": self.rng}, self.x) |
| 261 | + output = module.apply(variables, self.x) |
| 262 | + self.assertEqual(output.shape, self.x.shape) |
| 263 | + |
| 264 | + def test_additive_embedding_params_are_updated(self): |
| 265 | + """Tests that AdditiveSequenceEmbedding params are updated.""" |
| 266 | + module = sequence_embedders.AdditiveSequenceEmbedding(num_features=self.dim) |
| 267 | + variables = module.init({"params": self.rng}, self.x) |
| 268 | + initial_params = variables["params"] |
| 269 | + |
| 270 | + def loss_fn(params): |
| 271 | + output = module.apply({"params": params}, self.x) |
| 272 | + return jnp.sum(output) |
| 273 | + |
| 274 | + grads = jax.grad(loss_fn)(initial_params) |
| 275 | + |
| 276 | + # Check that the gradients are not zero. |
| 277 | + self.assertFalse(jnp.allclose(grads["PositionalEmbeddingTensor"], 0.0)) |
| 278 | + |
247 | 279 |
|
# Entry point: run all absltest test cases when executed as a script.
if __name__ == "__main__":
  absltest.main()
0 commit comments