| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Optional |
| |
|
| | import torch |
| | from einops import rearrange |
| |
|
| | from .ar_tokenizer_quantizers import FSQuantizer |
| |
|
| | |
| | |
| | torch._C._jit_set_texpr_fuser_enabled(False) |
| |
|
| |
|
def load_jit_model(jit_filepath: Optional[str] = None, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.

    Returns:
        The JIT compiled model loaded to device and on eval mode.

    Raises:
        ValueError: If ``jit_filepath`` is None.
    """
    # Fail fast with a clear message instead of letting torch.jit.load choke on None.
    if jit_filepath is None:
        raise ValueError("jit_filepath must be provided")

    # Re-disable the tensor-expression fuser in case it was re-enabled since import.
    torch._C._jit_set_texpr_fuser_enabled(False)

    model = torch.jit.load(jit_filepath)
    return model.eval().to(device)
| |
|
| |
|
class BaseDiscreteVideoFSQTokenizer(torch.nn.Module):
    """
    A base class for Discrete Video FSQ Tokenizer that handles data type conversions, and normalization
    using provided mean and standard deviation values for latent space representation.
    Derived classes should load pre-trained encoder and decoder components into encoder and decoder attributes.

    Attributes:
        encoder (Module | Callable): Encoder loaded from storage.
        decoder (Module | Callable): Decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one go to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one go to avoid memory overflow.
        levels (list[int], optional): The levels defined in the FSQ quantizer (default is [8, 8, 8, 5, 5, 5]).
        compression_ratio (list[int], optional): The compression factor for (T, H, W) (default is [8, 16, 16]).
    """

    def __init__(
        self,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: Optional[list[int]] = None,
        compression_ratio: Optional[list[int]] = None,
    ):
        super().__init__()
        self.channel = latent_ch
        self.name = name
        dtype = torch.bfloat16 if is_bf16 else torch.float32
        self.dtype = dtype
        self.pixel_chunk_duration = pixel_chunk_duration
        self.latent_chunk_duration = latent_chunk_duration
        self.max_enc_batch_size = max_enc_batch_size
        self.max_dec_batch_size = max_dec_batch_size
        # Avoid mutable default arguments: materialize the shared defaults here.
        self.levels = [8, 8, 8, 5, 5, 5] if levels is None else levels
        self.compress_ratio = [8, 16, 16] if compression_ratio is None else compression_ratio
        self.fsq_quantizer = FSQuantizer(self.levels)

    @property
    def latent_ch(self) -> int:
        """
        Returns the number of latent channels in the tokenizer.
        """
        return self.channel

    @torch.no_grad()
    def encode(
        self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes a video tensor into quantized latents and codebook indices.

        Args:
            state: Input video of shape (B, C, T, H, W).
            pixel_chunk_duration: Per-call chunk duration override; defaults to
                the duration configured at construction time.

        Returns:
            A tuple of (quantized latents, indices), each with the chunk axis
            folded back into the temporal axis.
        """
        B, C, T, H, W = state.shape
        if pixel_chunk_duration is None:
            # Use the chunk durations configured at construction time.
            pixel_chunk_duration = self.pixel_chunk_duration
            latent_chunk_duration = self.latent_chunk_duration
        else:
            # The first frame is kept as-is, so T_latent = 1 + (T_pixel - 1) / temporal_ratio.
            latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]

        assert (
            T % pixel_chunk_duration == 0
        ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}"
        # Split the video into temporal chunks and stack them along the batch axis.
        state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration)

        # Process in sub-batches to avoid memory overflow on large inputs.
        if state.shape[0] > self.max_enc_batch_size:
            quantized_out_list = []
            indices_list = []
            for i in range(0, state.shape[0], self.max_enc_batch_size):
                indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype))
                quantized_out_list.append(quantized_out)
                indices_list.append(indices)
            quantized_out = torch.cat(quantized_out_list, dim=0)
            indices = torch.cat(indices_list, dim=0)
        else:
            indices, quantized_out, _ = self.encoder(state.to(self.dtype))
        assert quantized_out.shape[2] == latent_chunk_duration
        # Fold the chunk axis back into the temporal axis.
        return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange(
            indices, "(b n) t h w -> b (n t) h w", b=B
        )

    @torch.no_grad()
    def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor:
        """
        Decodes codebook indices back into a pixel-space video tensor.

        Args:
            indices: Codebook indices of shape (B, T, H, W) at the latent resolution.
            pixel_chunk_duration: Per-call chunk duration override; defaults to
                the duration configured at construction time.

        Returns:
            Reconstructed video tensor with the chunk axis folded back into the
            temporal axis.
        """
        B, T, _, _ = indices.shape
        if pixel_chunk_duration is None:
            pixel_chunk_duration = self.pixel_chunk_duration
            latent_chunk_duration = self.latent_chunk_duration
        else:
            latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
        assert (
            T % latent_chunk_duration == 0
        ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}"
        # Split the latent sequence into chunks and stack them along the batch axis.
        indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration)

        # Process in sub-batches to avoid memory overflow on large inputs.
        if indices.shape[0] > self.max_dec_batch_size:
            state = []
            for i in range(0, indices.shape[0], self.max_dec_batch_size):
                state.append(self.decoder(indices[i : i + self.max_dec_batch_size]))
            state = torch.cat(state, dim=0)
        else:
            state = self.decoder(indices)

        assert state.shape[2] == pixel_chunk_duration
        return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B)

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the encoder and decoder to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.encoder.to(self.dtype)
| |
|
| |
|
class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer):
    """
    A JIT compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder
    and decoder components from a remote store, handles data type conversions, and normalization
    using provided mean and standard deviation values for latent space representation.

    Attributes:
        encoder (Module): The JIT compiled encoder loaded from storage.
        decoder (Module): The JIT compiled decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one go to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one go to avoid memory overflow.
        levels (list[int], optional): The levels defined in the FSQ quantizer (default is [8, 8, 8, 5, 5, 5]).
        compression_ratio (list[int], optional): The compression factor for (T, H, W) (default is [8, 16, 16]).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: Optional[list[int]] = None,
        compression_ratio: Optional[list[int]] = None,
    ):
        # Avoid mutable default arguments: resolve the defaults before delegating.
        levels = [8, 8, 8, 5, 5, 5] if levels is None else levels
        compression_ratio = [8, 16, 16] if compression_ratio is None else compression_ratio
        super().__init__(
            name,
            latent_ch,
            is_bf16,
            pixel_chunk_duration,
            latent_chunk_duration,
            max_enc_batch_size,
            max_dec_batch_size,
            levels,
            compression_ratio,
        )

        self.load_encoder(enc_fp)
        self.load_decoder(dec_fp)

    def load_encoder(self, enc_fp: str, device: str = "cuda") -> None:
        """
        Load the JIT compiled encoder from the remote store and freeze it.

        Args:
            enc_fp (str): File path to the encoder's JIT file on the remote store.
            device (str): Device to load the encoder onto (default is "cuda").
        """
        self.encoder = load_jit_model(enc_fp, device=device)
        # Inference only: freeze all parameters so no gradients are tracked.
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.to(self.dtype)

    def load_decoder(self, dec_fp: str, device: str = "cuda") -> None:
        """
        Load the JIT compiled decoder from the remote store and freeze it.

        Args:
            dec_fp (str): File path to the decoder's JIT file on the remote store.
            device (str): Device to load the decoder onto (default is "cuda").
        """
        self.decoder = load_jit_model(dec_fp, device=device)
        # Inference only: freeze all parameters so no gradients are tracked.
        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)
| |
|
| |
|
class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer):
    """
    A Discrete Video FSQ Tokenizer that loads weights from a pre-trained JITed encoder
    checkpoint into an nn.Module (so the encoder can be `torch.compile()`d) while keeping the
    JITed decoder; handles data type conversions, and normalization using provided mean and
    standard deviation values for latent space representation.

    Attributes:
        tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints.
        encoder (Callable): tokenizer_module's encode method.
        decoder (Callable): The JIT compiled decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        tokenizer_module (Module): Tokenizer module that will have its weights loaded.
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one go to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one go to avoid memory overflow.
        levels (list[int], optional): The levels defined in the FSQ quantizer (default is [8, 8, 8, 5, 5, 5]).
        compression_ratio (list[int], optional): The compression factor for (T, H, W) (default is [8, 16, 16]).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        tokenizer_module: torch.nn.Module,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: Optional[list[int]] = None,
        compression_ratio: Optional[list[int]] = None,
    ):
        # Avoid mutable default arguments: resolve the defaults before delegating.
        levels = [8, 8, 8, 5, 5, 5] if levels is None else levels
        compression_ratio = [8, 16, 16] if compression_ratio is None else compression_ratio
        super().__init__(
            name,
            latent_ch,
            is_bf16,
            pixel_chunk_duration,
            latent_chunk_duration,
            max_enc_batch_size,
            max_dec_batch_size,
            levels,
            compression_ratio,
        )

        self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module)

    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None:
        """
        Load the encoder weights into `tokenizer_module` and the JIT decoder from the remote store.

        Args:
            enc_fp (str): File path to the encoder's JIT file on the remote store.
            dec_fp (str): File path to the decoder's JIT file on the remote store.
            tokenizer_module (Module): Tokenizer module that was used to create the JIT checkpoints.
        """
        # The decoder stays as a frozen JIT module.
        self.decoder = load_jit_model(dec_fp)

        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)

        encoder_sd = load_jit_model(enc_fp).state_dict()

        # Only the encode path is needed from the module; drop its decoder half.
        del tokenizer_module.post_quant_conv
        del tokenizer_module.decoder

        # NOTE(review): these checkpoint keys are skipped — presumably non-trainable
        # buffers that tokenizer_module re-creates itself; confirm against the
        # tokenizer_module implementation.
        state_dict = {
            k: v
            for k, v in encoder_sd.items()
            if k
            not in (
                "encoder.patcher3d.wavelets",
                "encoder.patcher3d._arange",
                "encoder.patcher3d.patch_size_buffer",
                "quantizer._levels",
                "quantizer._basis",
                "quantizer.implicit_codebook",
            )
        }

        tokenizer_module.load_state_dict(state_dict)

        # Inference only: freeze all parameters so no gradients are tracked.
        tokenizer_module.eval()
        for param in tokenizer_module.parameters():
            param.requires_grad = False
        tokenizer_module.to(self.dtype)

        self.tokenizer_module = tokenizer_module
        # Expose the module's encode method under the base class's `encoder` attribute.
        self.encoder = self.tokenizer_module.encode

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the decoder and tokenizer module to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        # `self.encoder` is a bound method of tokenizer_module, so cast the module itself.
        self.tokenizer_module.to(self.dtype)
| |
|