diff --git a/hackable_diffusion/kdiff/core.py b/hackable_diffusion/kdiff/core.py index 970b952..225af46 100644 --- a/hackable_diffusion/kdiff/core.py +++ b/hackable_diffusion/kdiff/core.py @@ -111,7 +111,7 @@ def __call__( self, x0: DataTree, cond: Conditioning | None = None, - ) -> dict[str, dict[str, Array] | Array]: + ) -> dict[str, dict[str, Array] | Array | dict[str, dict[str, Array]]]: """Run the diffusion training step. Samples timesteps, corrupts the input data according to the corruption