Skip to content

Commit 851faa6

Browse files
ccrepyHackable Diffusion Authors
authored andcommitted
Cleanup *_step_sampler.py - fix some interfaces and use keyword arguments consistently
PiperOrigin-RevId: 911378722
1 parent 1b331f4 commit 851faa6

2 files changed

Lines changed: 35 additions & 7 deletions

File tree

hackable_diffusion/lib/sampling/gaussian_step_sampler.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def initialize(
9090
)
9191

9292
@kt.typechecked
93-
def update(
93+
def _update(
9494
self,
9595
prediction: TargetInfo,
9696
current_step: DiffusionStep,
@@ -137,14 +137,28 @@ def update(
137137
aux=dict(),
138138
)
139139

140+
@kt.typechecked
141+
def update(
142+
self,
143+
prediction: TargetInfo,
144+
current_step: DiffusionStep,
145+
next_step_info: StepInfo,
146+
) -> DiffusionStep:
147+
return self._update(
148+
prediction,
149+
current_step,
150+
next_step_info,
151+
stochastic=True,
152+
)
153+
140154
@kt.typechecked
141155
def finalize(
142156
self,
143157
prediction: TargetInfo,
144158
current_step: DiffusionStep,
145159
last_step_info: StepInfo,
146160
) -> DiffusionStep:
147-
return self.update(
161+
return self._update(
148162
prediction,
149163
current_step,
150164
last_step_info,
@@ -381,7 +395,7 @@ def initialize(
381395
)
382396

383397
@kt.typechecked
384-
def update(
398+
def _update(
385399
self,
386400
prediction: TargetInfo,
387401
current_step: DiffusionStep,
@@ -424,14 +438,28 @@ def update(
424438
aux=dict(),
425439
)
426440

441+
@kt.typechecked
442+
def update(
443+
self,
444+
prediction: TargetInfo,
445+
current_step: DiffusionStep,
446+
next_step_info: StepInfo,
447+
) -> DiffusionStep:
448+
return self._update(
449+
prediction,
450+
current_step,
451+
next_step_info,
452+
stochastic=True,
453+
)
454+
427455
@kt.typechecked
428456
def finalize(
429457
self,
430458
prediction: TargetInfo,
431459
current_step: DiffusionStep,
432460
last_step_info: StepInfo,
433461
) -> DiffusionStep:
434-
return self.update(
462+
return self._update(
435463
prediction,
436464
current_step,
437465
last_step_info,

hackable_diffusion/lib/sampling/simplicial_step_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def update(
160160

161161
# Get logits
162162
logits = self.corruption_process.convert_predictions(
163-
prediction,
164-
log_xt,
165-
time,
163+
prediction=prediction,
164+
xt=log_xt,
165+
time=time,
166166
)['logits']
167167

168168
# Sample hard token

0 commit comments

Comments
 (0)