Fix SDE inference producing noise when shift != 1.0

#3

The SDE inference path computes next_timestep using a linear formula:

next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)

This ignores the shift transformation applied to the timestep schedule:

t = shift * t / (1 + (shift - 1) * t)
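To make the mismatch concrete, here is a small sketch. It compares the shifted timestep the model expects next with the value the old linear formula produced. It assumes the unshifted schedule runs evenly from 1.0 (pure noise) down to 0.0 with infer_steps + 1 points. The helper names `shift_timestep` and `build_schedule` are hypothetical and are not taken from the repository.

```python
# Shift transformation quoted in the description.
def shift_timestep(t: float, shift: float) -> float:
    return shift * t / (1.0 + (shift - 1.0) * t)


def build_schedule(infer_steps: int, shift: float) -> list[float]:
    # Assumption: the unshifted schedule runs evenly from 1.0 (pure noise)
    # down to 0.0 (clean), i.e. infer_steps + 1 boundary points.
    linear = [1.0 - i / infer_steps for i in range(infer_steps + 1)]
    return [shift_timestep(t, shift) for t in linear]


infer_steps, shift = 4, 3.0
timesteps = build_schedule(infer_steps, shift)

for step_idx in range(infer_steps):
    t_prev = timesteps[step_idx + 1]                             # shifted: what the model expects next
    next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)    # old linear formula
    print(f"step {step_idx}: t_prev={t_prev:.2f} next_timestep={next_timestep:.2f}")
# step 0: t_prev=0.90 next_timestep=0.75
# step 1: t_prev=0.75 next_timestep=0.50
# step 2: t_prev=0.50 next_timestep=0.25
# step 3: t_prev=0.00 next_timestep=0.00
```

With shift = 3.0, every intermediate step renoises to a lower noise level than the schedule expects. Only the final step (t = 0) agrees.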

When shift != 1.0, the renoise step noises pred_clean to a timestep that doesn't match the shifted schedule the model expects at the next step. The error accumulates over the sampling steps, and the output degrades to pure noise.

t_prev is already available from the loop iterator and includes the shift transformation. The fix replaces the linear calculation with t_prev:

Before:

next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)
xt = self.renoise(pred_clean, next_timestep)

After:

xt = self.renoise(pred_clean, t_prev)
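The sketch below shows the structure of the fixed loop under stated assumptions. `renoise` is a standalone stand-in for `self.renoise` that uses the flow-matching convention x_t = (1 - t) * x0 + t * eps, and the real implementation may differ. `toy_model` is a hypothetical oracle that always returns the clean value. The loop variable `t_curr` is also hypothetical. Only `t_prev` comes from the description.

```python
import random

TRUE_X0 = 0.5  # hypothetical clean value the toy "model" always predicts


def shift_timestep(t: float, shift: float) -> float:
    return shift * t / (1.0 + (shift - 1.0) * t)


def toy_model(xt: float, t: float) -> float:
    # Hypothetical oracle denoiser standing in for the real network.
    return TRUE_X0


def renoise(pred_clean: float, t: float, rng: random.Random) -> float:
    # Assumption: flow-matching convention x_t = (1 - t) * x0 + t * eps.
    eps = rng.gauss(0.0, 1.0)
    return (1.0 - t) * pred_clean + t * eps


def sde_sample(infer_steps: int, shift: float, seed: int = 0) -> float:
    rng = random.Random(seed)
    timesteps = [shift_timestep(1.0 - i / infer_steps, shift) for i in range(infer_steps + 1)]
    xt = rng.gauss(0.0, 1.0)  # start from pure noise at t = 1.0
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        pred_clean = toy_model(xt, t_curr)
        # The fix: renoise to the shifted t_prev the model is called with next,
        # instead of recomputing an unshifted linear timestep.
        xt = renoise(pred_clean, t_prev, rng)
    return xt


print(sde_sample(infer_steps=8, shift=3.0))  # 0.5, because the last t_prev is 0.0
```

Because renoising and the next model call now share the same `t_prev`, the noise level in `xt` always matches the timestep the model is told, whatever the value of shift.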

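When shift == 1.0, the shift transformation reduces to the identity (1 * t / (1 + 0 * t) = t), so t_prev equals the old linear value and the behavior is identical to the current code. When shift != 1.0, the renoise step now follows the shifted schedule.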
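The equivalence claim can be checked numerically. This minimal sketch assumes, as the old formula implies, that the unshifted t_prev equals 1.0 - (step_idx + 1) / infer_steps. The helper `max_schedule_gap` is hypothetical. It returns the largest difference between the shifted t_prev and the old linear next_timestep over all steps.

```python
def shift_timestep(t: float, shift: float) -> float:
    return shift * t / (1.0 + (shift - 1.0) * t)


def max_schedule_gap(infer_steps: int, shift: float) -> float:
    # Largest |t_prev - next_timestep| across the sampling loop.
    gap = 0.0
    for step_idx in range(infer_steps):
        next_timestep = 1.0 - (float(step_idx + 1) / infer_steps)  # old linear formula
        t_prev = shift_timestep(next_timestep, shift)               # assumed shifted schedule
        gap = max(gap, abs(t_prev - next_timestep))
    return gap


print(max_schedule_gap(4, 1.0))  # 0.0
print(max_schedule_gap(4, 3.0))  # 0.25
```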

