"""sedd_wrapper.py
=================
This module provides a minimal HuggingFace-compatible wrapper around the
``SEDD`` architecture implemented in :pyfile:`model/transformer.py`.

The wrapper closely follows the design used in the Aero implementation that
lives in this code-base (see :pyfile:`configuration_aero.py` and
:pyfile:`modeling_aero.py`). Concretely, we expose three public objects:

* ``SEDDConfig``: a :class:`transformers.PretrainedConfig` subclass that
  stores the hyper-parameters needed to instantiate a ``SEDD`` model.
* ``SEDDModel``: a :class:`transformers.PreTrainedModel` subclass that
  internally holds an instance of the original ``SEDD`` network and maps
  ``input_ids`` + ``sigma`` to vocabulary logits.
* ``SEDDOutput``: a thin :class:`transformers.modeling_outputs.ModelOutput`
  dataclass that mirrors the usual "logits / loss" structure.

With this wrapper a trained model checkpoint can be pushed to / loaded from
the 🤗 Hub via ``SEDDModel.push_to_hub`` / ``SEDDModel.from_pretrained``, the
same way as any other ``transformers`` model.
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from model.transformer import SEDD as _OrigSEDD

# ``omegaconf`` is optional: it is only needed to build the nested Hydra
# config that the original SEDD implementation expects.
try:
    from omegaconf import OmegaConf
except ImportError:
    OmegaConf = None

logger = logging.get_logger(__name__)


class SEDDConfig(PretrainedConfig):
| | """Configuration class for the SEDD architecture. |
| | |
| | The defaults reproduce *roughly* the "small" configuration shipped in |
| | ``configs/model/small.yaml``. Additional keys that are present in the |
| | original Hydra config but not required for instantiation (e.g. *training* |
| | hyper-parameters) are deliberately omitted here – they can still be stored |
| | as *extra* fields in the underlying JSON if a user wishes to preserve them. |
| | """ |
| |
|
| | model_type: str = "sedd" |
| |
|
    def __init__(
        self,
        *,
        # Vocabulary size.
        tokens: int = 50257,
        # Diffusion graph type (e.g. "absorb").
        graph_type: str = "absorb",
        # Score-network (transformer) hyper-parameters.
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # HuggingFace-specific option.
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Vocabulary and graph settings.
        self.tokens = tokens
        self.graph_type = graph_type

        # Score-network hyper-parameters (flattened with a ``model_`` prefix).
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    def to_hydra(self):
        """Convert this *flat* config into the nested OmegaConf structure
        that the reference ``SEDD`` implementation expects.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": {
                "type": self.graph_type,
            },
            "model": {
                "hidden_size": self.model_hidden_size,
                "cond_dim": self.model_cond_dim,
                "length": self.model_length,
                "n_blocks": self.model_n_blocks,
                "n_heads": self.model_n_heads,
                "scale_by_sigma": self.model_scale_by_sigma,
                "dropout": self.model_dropout,
            },
        }
        return OmegaConf.create(nested)
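    # Sketch of the resulting structure (values shown are the defaults):
    #
    #     cfg = SEDDConfig()
    #     hydra_cfg = cfg.to_hydra()
    #     hydra_cfg.graph.type            # -> "absorb"
    #     hydra_cfg.model.hidden_size     # -> 768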


@dataclass
class SEDDOutput(ModelOutput):
| | """Standard output for :class:`SEDDModel`. |
| | |
| | Attributes |
| | ---------- |
| | loss: |
| | *Optional* scalar returned when ``labels`` are provided. |
| | logits: |
| | The raw vocabulary logits computed by the model of shape |
| | ``(batch_size, sequence_length, vocab_size)``. |
| | """ |
| |
|
| | loss: Optional[torch.FloatTensor] = None |
| | logits: torch.FloatTensor | None = None |
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | class SEDDModel(PreTrainedModel): |
| | """HuggingFace *Transformers* wrapper around the original ``SEDD`` model.""" |
| |
|
| | config_class = SEDDConfig |
| | base_model_prefix = "score_model" |
| | _no_split_modules: List[str] = [ |
| | "DDiTBlock", |
| | ] |
| |
|
| | def __init__(self, config: SEDDConfig): |
| | super().__init__(config) |

        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to instantiate SEDD")

        # Rebuild the nested Hydra config and instantiate the original network.
        hydra_cfg = config.to_hydra()
        self.score_model = _OrigSEDD(hydra_cfg)

        # Weight initialisation and final processing (HF convention).
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        sigma: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[SEDDOutput, Tuple]:
        """Run a forward pass.

        Parameters
        ----------
        input_ids:
            Token indices of shape ``(batch_size, seq_len)``.
        sigma:
            Noise level ("time-step") of shape ``(batch_size,)``.
        labels:
            *Optional* label tensor used to compute a cross-entropy training
            loss. If provided, the returned :class:`SEDDOutput` will contain
            a ``loss`` field.
        """
        logits = self.score_model(indices=input_ids, sigma=sigma)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Flatten (batch, seq) so cross-entropy sees one token per row.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not self.config.use_return_dict:
            output: Tuple[Any, ...] = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SEDDOutput(loss=loss, logits=logits)
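    # Example call (a sketch; shapes are illustrative):
    #
    #     input_ids = torch.randint(0, config.tokens, (2, config.model_length))
    #     sigma = torch.rand(2)           # one noise level per batch element
    #     out = model(input_ids=input_ids, sigma=sigma)
    #     out.logits.shape                # (2, model_length, tokens)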

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Override the default method to also support legacy SEDD checkpoints
        whose weights were saved via ``torch.save({'model': state_dict, ...})``.
        """
        try:
            # First try the standard ``transformers`` loading path (native
            # ``safetensors`` / ``pytorch_model.bin`` checkpoints).
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        except (EnvironmentError, RuntimeError) as e:
            logger.info(
                "Falling back to legacy SEDD checkpoint format because standard "
                "loading raised: %s",
                e,
            )

        config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        # ``SEDDModel.__init__`` only accepts a config, so the remaining
        # positional args and kwargs are intentionally not forwarded here.
        model = cls(config)

        checkpoint_path = os.path.join(
            pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
        )
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                f"Could not find legacy SEDD checkpoint at {checkpoint_path}"
            )

        ckpt = torch.load(checkpoint_path, map_location="cpu")
        state_dict = ckpt.get("model", ckpt)
        # Strip a potential ``module.`` prefix left over from DDP training.
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning("Missing keys when loading SEDD weights: %s", missing)
        if unexpected:
            logger.warning("Unexpected keys when loading SEDD weights: %s", unexpected)
        return model
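    # Expected legacy on-disk layout (a sketch inferred from the loader above;
    # the ``config.json`` is the serialized ``SEDDConfig``):
    #
    #     <pretrained_model_name_or_path>/
    #         config.json
    #         checkpoints-meta/
    #             checkpoint.pth          # torch.save({"model": state_dict, ...})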


__all__ = [
    "SEDDConfig",
    "SEDDModel",
    "SEDDOutput",
]