diff --git a/hackable_diffusion/lib/utils.py b/hackable_diffusion/lib/utils.py index 400bf09..4936c6b 100644 --- a/hackable_diffusion/lib/utils.py +++ b/hackable_diffusion/lib/utils.py @@ -375,9 +375,9 @@ def get_dummy_batch_fixed_dtype( def _get_dummy(shape: tuple[int, ...]) -> jnp.ndarray: if only_first_axis: - return jnp.empty(shape=(shape[0],), dtype=dtype) + return jnp.zeros(shape=(shape[0],), dtype=dtype) else: - return jnp.empty(shape=shape, dtype=dtype) + return jnp.zeros(shape=shape, dtype=dtype) return jax.tree.map( _get_dummy, @@ -409,9 +409,9 @@ def get_dummy_batch( def _get_dummy(shape: tuple[int, ...], dtype: DType) -> jnp.ndarray: if only_first_axis: - return jnp.empty(shape=(shape[0],), dtype=dtype) + return jnp.zeros(shape=(shape[0],), dtype=dtype) else: - return jnp.empty(shape=shape, dtype=dtype) + return jnp.zeros(shape=shape, dtype=dtype) return jax.tree.map( _get_dummy,