File tree Expand file tree Collapse file tree
hackable_diffusion/lib/sampling Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -90,7 +90,7 @@ def initialize(
9090 )
9191
9292 @kt .typechecked
93- def update (
93+ def _update (
9494 self ,
9595 prediction : TargetInfo ,
9696 current_step : DiffusionStep ,
@@ -137,14 +137,28 @@ def update(
137137 aux = dict (),
138138 )
139139
140+ @kt .typechecked
141+ def update (
142+ self ,
143+ prediction : TargetInfo ,
144+ current_step : DiffusionStep ,
145+ next_step_info : StepInfo ,
146+ ) -> DiffusionStep :
147+ return self ._update (
148+ prediction ,
149+ current_step ,
150+ next_step_info ,
151+ stochastic = True ,
152+ )
153+
140154 @kt .typechecked
141155 def finalize (
142156 self ,
143157 prediction : TargetInfo ,
144158 current_step : DiffusionStep ,
145159 last_step_info : StepInfo ,
146160 ) -> DiffusionStep :
147- return self .update (
161+ return self ._update (
148162 prediction ,
149163 current_step ,
150164 last_step_info ,
@@ -381,7 +395,7 @@ def initialize(
381395 )
382396
383397 @kt .typechecked
384- def update (
398+ def _update (
385399 self ,
386400 prediction : TargetInfo ,
387401 current_step : DiffusionStep ,
@@ -424,14 +438,28 @@ def update(
424438 aux = dict (),
425439 )
426440
441+ @kt .typechecked
442+ def update (
443+ self ,
444+ prediction : TargetInfo ,
445+ current_step : DiffusionStep ,
446+ next_step_info : StepInfo ,
447+ ) -> DiffusionStep :
448+ return self ._update (
449+ prediction ,
450+ current_step ,
451+ next_step_info ,
452+ stochastic = True ,
453+ )
454+
427455 @kt .typechecked
428456 def finalize (
429457 self ,
430458 prediction : TargetInfo ,
431459 current_step : DiffusionStep ,
432460 last_step_info : StepInfo ,
433461 ) -> DiffusionStep :
434- return self .update (
462+ return self ._update (
435463 prediction ,
436464 current_step ,
437465 last_step_info ,
Original file line number Diff line number Diff line change @@ -160,9 +160,9 @@ def update(
160160
161161 # Get logits
162162 logits = self .corruption_process .convert_predictions (
163- prediction ,
164- log_xt ,
165- time ,
163+ prediction = prediction ,
164+ xt = log_xt ,
165+ time = time ,
166166 )['logits' ]
167167
168168 # Sample hard token
You can’t perform that action at this time.
0 commit comments