Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.utils import BaseOutput, is_scipy_available, logging | |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin | |
| if is_scipy_available(): | |
| import scipy.stats | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
| class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): | |
| """ | |
| Output class for the scheduler's `step` function output. | |
| Args: | |
| prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): | |
| Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the | |
| denoising loop. | |
| """ | |
| latents: torch.FloatTensor | |
| log_probs: torch.FloatTensor | |
| class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): | |
| """ | |
| Euler scheduler. | |
| This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic | |
| methods the library implements for all schedulers such as loading and saving. | |
| Args: | |
| num_train_timesteps (`int`, defaults to 1000): | |
| The number of diffusion steps to train the model. | |
| shift (`float`, defaults to 1.0): | |
| The shift value for the timestep schedule. | |
| use_dynamic_shifting (`bool`, defaults to False): | |
| Whether to apply timestep shifting on-the-fly based on the image resolution. | |
| base_shift (`float`, defaults to 0.5): | |
| Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent | |
| with desired output. | |
| max_shift (`float`, defaults to 1.15): | |
| Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be | |
| more exaggerated or stylized. | |
| base_image_seq_len (`int`, defaults to 256): | |
| The base image sequence length. | |
| max_image_seq_len (`int`, defaults to 4096): | |
| The maximum image sequence length. | |
| invert_sigmas (`bool`, defaults to False): | |
| Whether to invert the sigmas. | |
| shift_terminal (`float`, defaults to None): | |
| The end value of the shifted timestep schedule. | |
| use_karras_sigmas (`bool`, defaults to False): | |
| Whether to use Karras sigmas for step sizes in the noise schedule during sampling. | |
| use_exponential_sigmas (`bool`, defaults to False): | |
| Whether to use exponential sigmas for step sizes in the noise schedule during sampling. | |
| use_beta_sigmas (`bool`, defaults to False): | |
| Whether to use beta sigmas for step sizes in the noise schedule during sampling. | |
| time_shift_type (`str`, defaults to "exponential"): | |
| The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". | |
| stochastic_sampling (`bool`, defaults to False): | |
| Whether to use stochastic sampling. | |
| """ | |
| _compatibles = [] | |
| order = 1 | |
| def __init__( | |
| self, | |
| num_train_timesteps: int = 1000, | |
| shift: float = 1.0, | |
| use_dynamic_shifting: bool = False, | |
| base_shift: Optional[float] = 0.5, | |
| max_shift: Optional[float] = 1.15, | |
| base_image_seq_len: Optional[int] = 256, | |
| max_image_seq_len: Optional[int] = 4096, | |
| invert_sigmas: bool = False, | |
| shift_terminal: Optional[float] = None, | |
| use_karras_sigmas: Optional[bool] = False, | |
| use_exponential_sigmas: Optional[bool] = False, | |
| use_beta_sigmas: Optional[bool] = False, | |
| time_shift_type: str = "exponential", | |
| stochastic_sampling: bool = True, | |
| ): | |
| if self.config.use_beta_sigmas and not is_scipy_available(): | |
| raise ImportError("Make sure to install scipy if you want to use beta sigmas.") | |
| if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: | |
| raise ValueError( | |
| "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." | |
| ) | |
| if time_shift_type not in {"exponential", "linear"}: | |
| raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.") | |
| timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() | |
| timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) | |
| sigmas = timesteps / num_train_timesteps | |
| if not use_dynamic_shifting: | |
| # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution | |
| sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) | |
| self.timesteps = sigmas * num_train_timesteps | |
| self._step_index = None | |
| self._begin_index = None | |
| self._shift = shift | |
| self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication | |
| self.sigma_min = self.sigmas[-1].item() | |
| self.sigma_max = self.sigmas[0].item() | |
| def shift(self): | |
| """ | |
| The value used for shifting. | |
| """ | |
| return self._shift | |
| def step_index(self): | |
| """ | |
| The index counter for current timestep. It will increase 1 after each scheduler step. | |
| """ | |
| return self._step_index | |
| def begin_index(self): | |
| """ | |
| The index for the first timestep. It should be set from pipeline with `set_begin_index` method. | |
| """ | |
| return self._begin_index | |
| # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index | |
| def set_begin_index(self, begin_index: int = 0): | |
| """ | |
| Sets the begin index for the scheduler. This function should be run from pipeline before the inference. | |
| Args: | |
| begin_index (`int`): | |
| The begin index for the scheduler. | |
| """ | |
| self._begin_index = begin_index | |
| def set_shift(self, shift: float): | |
| self._shift = shift | |
| def scale_noise( | |
| self, | |
| sample: torch.FloatTensor, | |
| timestep: Union[float, torch.FloatTensor], | |
| noise: Optional[torch.FloatTensor] = None, | |
| ) -> torch.FloatTensor: | |
| """ | |
| Forward process in flow-matching | |
| Args: | |
| sample (`torch.FloatTensor`): | |
| The input sample. | |
| timestep (`int`, *optional*): | |
| The current timestep in the diffusion chain. | |
| Returns: | |
| `torch.FloatTensor`: | |
| A scaled input sample. | |
| """ | |
| # Make sure sigmas and timesteps have the same device and dtype as original_samples | |
| sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) | |
| if sample.device.type == "mps" and torch.is_floating_point(timestep): | |
| # mps does not support float64 | |
| schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) | |
| timestep = timestep.to(sample.device, dtype=torch.float32) | |
| else: | |
| schedule_timesteps = self.timesteps.to(sample.device) | |
| timestep = timestep.to(sample.device) | |
| # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index | |
| if self.begin_index is None: | |
| step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] | |
| elif self.step_index is not None: | |
| # add_noise is called after first denoising step (for inpainting) | |
| step_indices = [self.step_index] * timestep.shape[0] | |
| else: | |
| # add noise is called before first denoising step to create initial latent(img2img) | |
| step_indices = [self.begin_index] * timestep.shape[0] | |
| sigma = sigmas[step_indices].flatten() | |
| while len(sigma.shape) < len(sample.shape): | |
| sigma = sigma.unsqueeze(-1) | |
| sample = sigma * noise + (1.0 - sigma) * sample | |
| return sample | |
| def _sigma_to_t(self, sigma): | |
| return sigma * self.config.num_train_timesteps | |
| def time_shift(self, mu: float, sigma: float, t: torch.Tensor): | |
| if self.config.time_shift_type == "exponential": | |
| return self._time_shift_exponential(mu, sigma, t) | |
| elif self.config.time_shift_type == "linear": | |
| return self._time_shift_linear(mu, sigma, t) | |
| def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: | |
| r""" | |
| Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config | |
| value. | |
| Reference: | |
| https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 | |
| Args: | |
| t (`torch.Tensor`): | |
| A tensor of timesteps to be stretched and shifted. | |
| Returns: | |
| `torch.Tensor`: | |
| A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. | |
| """ | |
| one_minus_z = 1 - t | |
| scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) | |
| stretched_t = 1 - (one_minus_z / scale_factor) | |
| return stretched_t | |
| def set_timesteps( | |
| self, | |
| num_inference_steps: Optional[int] = None, | |
| device: Union[str, torch.device] = None, | |
| sigmas: Optional[List[float]] = None, | |
| mu: Optional[float] = None, | |
| timesteps: Optional[List[float]] = None, | |
| ): | |
| """ | |
| Sets the discrete timesteps used for the diffusion chain (to be run before inference). | |
| Args: | |
| num_inference_steps (`int`, *optional*): | |
| The number of diffusion steps used when generating samples with a pre-trained model. | |
| device (`str` or `torch.device`, *optional*): | |
| The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. | |
| sigmas (`List[float]`, *optional*): | |
| Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed | |
| automatically. | |
| mu (`float`, *optional*): | |
| Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep | |
| shifting. | |
| timesteps (`List[float]`, *optional*): | |
| Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed | |
| automatically. | |
| """ | |
| if self.config.use_dynamic_shifting and mu is None: | |
| raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") | |
| if sigmas is not None and timesteps is not None: | |
| if len(sigmas) != len(timesteps): | |
| raise ValueError("`sigmas` and `timesteps` should have the same length") | |
| if num_inference_steps is not None: | |
| if (sigmas is not None and len(sigmas) != num_inference_steps) or ( | |
| timesteps is not None and len(timesteps) != num_inference_steps | |
| ): | |
| raise ValueError( | |
| "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" | |
| ) | |
| else: | |
| num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps) | |
| self.num_inference_steps = num_inference_steps | |
| # 1. Prepare default sigmas | |
| is_timesteps_provided = timesteps is not None | |
| if is_timesteps_provided: | |
| timesteps = np.array(timesteps).astype(np.float32) | |
| if sigmas is None: | |
| if timesteps is None: | |
| timesteps = np.linspace( | |
| self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps | |
| ) | |
| sigmas = timesteps / self.config.num_train_timesteps | |
| else: | |
| sigmas = np.array(sigmas).astype(np.float32) | |
| num_inference_steps = len(sigmas) | |
| # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of | |
| # "exponential" or "linear" type is applied | |
| if self.config.use_dynamic_shifting: | |
| sigmas = self.time_shift(mu, 1.0, sigmas) | |
| else: | |
| sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) | |
| # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value | |
| if self.config.shift_terminal: | |
| sigmas = self.stretch_shift_to_terminal(sigmas) | |
| # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules | |
| if self.config.use_karras_sigmas: | |
| sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) | |
| elif self.config.use_exponential_sigmas: | |
| sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) | |
| elif self.config.use_beta_sigmas: | |
| sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) | |
| # 5. Convert sigmas and timesteps to tensors and move to specified device | |
| sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) | |
| if not is_timesteps_provided: | |
| timesteps = sigmas * self.config.num_train_timesteps | |
| else: | |
| timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) | |
| # 6. Append the terminal sigma value. | |
| # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the | |
| # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi | |
| if self.config.invert_sigmas: | |
| sigmas = 1.0 - sigmas | |
| timesteps = sigmas * self.config.num_train_timesteps | |
| sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) | |
| else: | |
| sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) | |
| self.timesteps = timesteps | |
| self.sigmas = sigmas | |
| self._step_index = None | |
| self._begin_index = None | |
| # def index_for_timestep(self, timestep, schedule_timesteps=None): | |
| # if schedule_timesteps is None: | |
| # schedule_timesteps = self.timesteps | |
| # indices = (schedule_timesteps == timestep).nonzero() | |
| # # The sigma index that is taken for the **very** first `step` | |
| # # is always the second index (or the last index if there is only 1) | |
| # # This way we can ensure we don't accidentally skip a sigma in | |
| # # case we start in the middle of the denoising schedule (e.g. for image-to-image) | |
| # pos = 1 if len(indices) > 1 else 0 | |
| # return indices[pos].item() | |
| def index_for_timestep(self, timestep, schedule_timesteps=None): | |
| if schedule_timesteps is None: | |
| schedule_timesteps = self.timesteps | |
| match = (schedule_timesteps[None, :] == timestep[:, None]) | |
| cols = torch.arange(schedule_timesteps.numel()) | |
| cols = cols.expand(timestep.numel(), -1) | |
| match=match.to(cols.device) | |
| idx_last = torch.where(match, cols, torch.full_like(cols, -1)).max(dim=1).values | |
| return idx_last | |
| def _init_step_index(self, timestep): | |
| if self.begin_index is None: | |
| if isinstance(timestep, torch.Tensor): | |
| timestep = timestep.to(self.timesteps.device) | |
| self._step_index = self.index_for_timestep(timestep) | |
| else: | |
| self._step_index = self._begin_index | |
| def step( | |
| self, | |
| model_output: torch.FloatTensor, | |
| timestep: Union[float, torch.FloatTensor], | |
| sample: torch.FloatTensor, | |
| s_churn: float = 0.0, | |
| s_tmin: float = 0.0, | |
| s_tmax: float = float("inf"), | |
| s_noise: float = 1.0, | |
| generator: Optional[torch.Generator] = None, | |
| prev_sample: Optional[torch.FloatTensor] = None, | |
| per_token_timesteps: Optional[torch.Tensor] = None, | |
| return_dict: bool = True, | |
| init_step: bool = False, | |
| ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: | |
| """ | |
| Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion | |
| process from the learned model outputs (most often the predicted noise). | |
| Args: | |
| model_output (`torch.FloatTensor`): | |
| The direct output from learned diffusion model. | |
| timestep (`float`): | |
| The current discrete timestep in the diffusion chain. | |
| sample (`torch.FloatTensor`): | |
| A current instance of a sample created by the diffusion process. | |
| s_churn (`float`): | |
| s_tmin (`float`): | |
| s_tmax (`float`): | |
| s_noise (`float`, defaults to 1.0): | |
| Scaling factor for noise added to the sample. | |
| generator (`torch.Generator`, *optional*): | |
| A random number generator. | |
| per_token_timesteps (`torch.Tensor`, *optional*): | |
| The timesteps for each token in the sample. | |
| return_dict (`bool`): | |
| Whether or not to return a | |
| [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. | |
| Returns: | |
| [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: | |
| If return_dict is `True`, | |
| [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, | |
| otherwise a tuple is returned where the first element is the sample tensor. | |
| """ | |
| if ( | |
| isinstance(timestep, int) | |
| or isinstance(timestep, torch.IntTensor) | |
| or isinstance(timestep, torch.LongTensor) | |
| ): | |
| raise ValueError( | |
| ( | |
| "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" | |
| " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass" | |
| " one of the `scheduler.timesteps` as a timestep." | |
| ), | |
| ) | |
| if self.step_index is None: | |
| self._init_step_index(timestep) | |
| # if init_step: | |
| # self._init_step_index(timestep) | |
| # Upcast to avoid precision issues when computing prev_sample | |
| # sample = sample.to(torch.float32) | |
| if per_token_timesteps is not None: | |
| per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps | |
| sigmas = self.sigmas[:, None, None] | |
| lower_mask = sigmas < per_token_sigmas[None] - 1e-6 | |
| lower_sigmas = lower_mask * sigmas | |
| lower_sigmas, _ = lower_sigmas.max(dim=0) | |
| current_sigma = per_token_sigmas[..., None] | |
| next_sigma = lower_sigmas[..., None] | |
| dt = current_sigma - next_sigma | |
| else: | |
| sigma_idx = self.step_index | |
| if init_step: | |
| sigma_idx = self.index_for_timestep(timestep) | |
| sigma = self.sigmas[sigma_idx] | |
| sigma_next = self.sigmas[sigma_idx + 1] | |
| current_sigma = sigma | |
| next_sigma = sigma_next | |
| if len(current_sigma.shape)> 0: | |
| current_sigma = current_sigma[:, None, None] | |
| next_sigma = next_sigma[:, None, None] | |
| dt = sigma_next - sigma | |
| log_prob = None | |
| if self.config.stochastic_sampling: | |
| # if len(current_sigma.shape)> 0: | |
| # current_sigma = current_sigma[:, None, None] | |
| # print(f"model_output {model_output.shape}") | |
| # print(f"sigma {current_sigma.shape}") | |
| # print(f"sample {sample.shape}") | |
| x0 = sample - current_sigma * model_output | |
| if prev_sample is None: | |
| if generator is None: | |
| generator = torch.Generator(device=sample.device) | |
| generator.seed() | |
| noise = torch.randn(sample.size(), generator=generator, device=sample.device, dtype=sample.dtype) | |
| prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise | |
| prev_sample_mean = (1-next_sigma)*x0 | |
| prev_sample_ =prev_sample.clone() | |
| diff = prev_sample_.detach() - prev_sample_mean | |
| log_prob = - (diff**2) / (2 * (next_sigma+1e-7)**2) \ | |
| - torch.log(next_sigma) \ | |
| - 0.5 * torch.log(torch.tensor(2 * np.pi)) | |
| log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) | |
| else: | |
| prev_sample = sample + dt * model_output | |
| # upon completion increase step index by one | |
| self._step_index += 1 | |
| if per_token_timesteps is None: | |
| # Cast sample back to model compatible dtype | |
| prev_sample = prev_sample.to(model_output.dtype) | |
| # if log_prob == None: | |
| # raise ValueError("log_prob is None, stochastic_sampling is off") | |
| if not return_dict: | |
| return (prev_sample,log_prob) | |
| return FlowMatchEulerDiscreteSchedulerOutput(latents=prev_sample, log_probs=log_prob) | |
| # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras | |
| def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: | |
| """Constructs the noise schedule of Karras et al. (2022).""" | |
| # Hack to make sure that other schedulers which copy this function don't break | |
| # TODO: Add this logic to the other schedulers | |
| if hasattr(self.config, "sigma_min"): | |
| sigma_min = self.config.sigma_min | |
| else: | |
| sigma_min = None | |
| if hasattr(self.config, "sigma_max"): | |
| sigma_max = self.config.sigma_max | |
| else: | |
| sigma_max = None | |
| sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() | |
| sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() | |
| rho = 7.0 # 7.0 is the value used in the paper | |
| ramp = np.linspace(0, 1, num_inference_steps) | |
| min_inv_rho = sigma_min ** (1 / rho) | |
| max_inv_rho = sigma_max ** (1 / rho) | |
| sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho | |
| return sigmas | |
| # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential | |
| def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: | |
| """Constructs an exponential noise schedule.""" | |
| # Hack to make sure that other schedulers which copy this function don't break | |
| # TODO: Add this logic to the other schedulers | |
| if hasattr(self.config, "sigma_min"): | |
| sigma_min = self.config.sigma_min | |
| else: | |
| sigma_min = None | |
| if hasattr(self.config, "sigma_max"): | |
| sigma_max = self.config.sigma_max | |
| else: | |
| sigma_max = None | |
| sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() | |
| sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() | |
| sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) | |
| return sigmas | |
| # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta | |
| def _convert_to_beta( | |
| self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 | |
| ) -> torch.Tensor: | |
| """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" | |
| # Hack to make sure that other schedulers which copy this function don't break | |
| # TODO: Add this logic to the other schedulers | |
| if hasattr(self.config, "sigma_min"): | |
| sigma_min = self.config.sigma_min | |
| else: | |
| sigma_min = None | |
| if hasattr(self.config, "sigma_max"): | |
| sigma_max = self.config.sigma_max | |
| else: | |
| sigma_max = None | |
| sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() | |
| sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() | |
| sigmas = np.array( | |
| [ | |
| sigma_min + (ppf * (sigma_max - sigma_min)) | |
| for ppf in [ | |
| scipy.stats.beta.ppf(timestep, alpha, beta) | |
| for timestep in 1 - np.linspace(0, 1, num_inference_steps) | |
| ] | |
| ] | |
| ) | |
| return sigmas | |
| def _time_shift_exponential(self, mu, sigma, t): | |
| return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) | |
| def _time_shift_linear(self, mu, sigma, t): | |
| return mu / (mu + (1 / t - 1) ** sigma) | |
| def __len__(self): | |
| return self.config.num_train_timesteps | |