sam-motamed committed
Commit c6e4b0e · verified · 1 Parent(s): a3d47af

Add diffusers/cogvideox_transformer3d.py

Files changed (1)
  1. diffusers/cogvideox_transformer3d.py +845 -0
diffusers/cogvideox_transformer3d.py ADDED
@@ -0,0 +1,845 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import Attention, FeedForward
25
+ from diffusers.models.attention_processor import (
26
+ AttentionProcessor, CogVideoXAttnProcessor2_0,
27
+ FusedCogVideoXAttnProcessor2_0)
28
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
29
+ TimestepEmbedding, Timesteps,
30
+ get_3d_sincos_pos_embed)
31
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
34
+ from diffusers.utils import is_torch_version, logging
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from torch import nn
37
+
38
+ from dist_utils import (get_sequence_parallel_rank,
39
+ get_sequence_parallel_world_size,
40
+ get_sp_group,
41
+ xFuserLongContextAttention)
42
+ from dist_utils import CogVideoXMultiGPUsAttnProcessor2_0
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
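+ # Note: the CogVideoXPatchEmbed defined below shadows the class of the same name imported from
+ # diffusers.models.embeddings above. The local copy differs mainly in its forward pass, which
+ # trilinearly interpolates the positional embeddings so that inputs at resolutions other than the
+ # configured sample size can still be embedded.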
48
+ class CogVideoXPatchEmbed(nn.Module):
49
+ def __init__(
50
+ self,
51
+ patch_size: int = 2,
52
+ patch_size_t: Optional[int] = None,
53
+ in_channels: int = 16,
54
+ embed_dim: int = 1920,
55
+ text_embed_dim: int = 4096,
56
+ bias: bool = True,
57
+ sample_width: int = 90,
58
+ sample_height: int = 60,
59
+ sample_frames: int = 49,
60
+ temporal_compression_ratio: int = 4,
61
+ max_text_seq_length: int = 226,
62
+ spatial_interpolation_scale: float = 1.875,
63
+ temporal_interpolation_scale: float = 1.0,
64
+ use_positional_embeddings: bool = True,
65
+ use_learned_positional_embeddings: bool = True,
66
+ ) -> None:
67
+ super().__init__()
68
+
69
+ post_patch_height = sample_height // patch_size
70
+ post_patch_width = sample_width // patch_size
71
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
72
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
73
+ self.post_patch_height = post_patch_height
74
+ self.post_patch_width = post_patch_width
75
+ self.post_time_compression_frames = post_time_compression_frames
76
+ self.patch_size = patch_size
77
+ self.patch_size_t = patch_size_t
78
+ self.embed_dim = embed_dim
79
+ self.sample_height = sample_height
80
+ self.sample_width = sample_width
81
+ self.sample_frames = sample_frames
82
+ self.temporal_compression_ratio = temporal_compression_ratio
83
+ self.max_text_seq_length = max_text_seq_length
84
+ self.spatial_interpolation_scale = spatial_interpolation_scale
85
+ self.temporal_interpolation_scale = temporal_interpolation_scale
86
+ self.use_positional_embeddings = use_positional_embeddings
87
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
88
+
89
+ if patch_size_t is None:
90
+ # CogVideoX 1.0 checkpoints
91
+ self.proj = nn.Conv2d(
92
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
93
+ )
94
+ else:
95
+ # CogVideoX 1.5 checkpoints
96
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
97
+
98
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
99
+
100
+ if use_positional_embeddings or use_learned_positional_embeddings:
101
+ persistent = use_learned_positional_embeddings
102
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
103
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
104
+
105
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
106
+ post_patch_height = sample_height // self.patch_size
107
+ post_patch_width = sample_width // self.patch_size
108
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
109
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
110
+
111
+ pos_embedding = get_3d_sincos_pos_embed(
112
+ self.embed_dim,
113
+ (post_patch_width, post_patch_height),
114
+ post_time_compression_frames,
115
+ self.spatial_interpolation_scale,
116
+ self.temporal_interpolation_scale,
117
+ )
118
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
119
+ joint_pos_embedding = torch.zeros(
120
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
121
+ )
122
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
123
+
124
+ return joint_pos_embedding
125
+
126
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
127
+ r"""
128
+ Args:
129
+ text_embeds (`torch.Tensor`):
130
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
131
+ image_embeds (`torch.Tensor`):
132
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
133
+ """
134
+ text_embeds = self.text_proj(text_embeds)
135
+
136
+ text_batch_size, text_seq_length, text_channels = text_embeds.shape
137
+ batch_size, num_frames, channels, height, width = image_embeds.shape
138
+
139
+ if self.patch_size_t is None:
140
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
141
+ image_embeds = self.proj(image_embeds)
142
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
143
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
144
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
145
+ else:
146
+ p = self.patch_size
147
+ p_t = self.patch_size_t
148
+
149
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
150
+ # b, f, h, w, c => b, f // p_t, p_t, h // p, p, w // p, p, c
151
+ image_embeds = image_embeds.reshape(
152
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
153
+ )
154
+ # b, f // p_t, p_t, h // p, p, w // p, p, c => b, f // p_t, h // p, w // p, c, p_t, p, p
155
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
156
+ image_embeds = self.proj(image_embeds)
157
+
158
+ embeds = torch.cat(
159
+ [text_embeds, image_embeds], dim=1
160
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
161
+
162
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
163
+ seq_length = height * width * num_frames // (self.patch_size**2)
164
+ # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
165
+ pos_embeds = self.pos_embedding
166
+ emb_size = embeds.size()[-1]
167
+ pos_embeds_without_text = pos_embeds[:, text_seq_length:].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
168
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
169
+ pos_embeds_without_text = F.interpolate(pos_embeds_without_text, size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
170
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
171
+ pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim=1)
172
+ pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
173
+ embeds = embeds + pos_embeds
174
+
175
+ return embeds
176
+
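+ # Illustrative sketch (not part of the original file): how this patch embedding combines text and
+ # video tokens under the default CogVideoX 1.0 configuration (patch_size=2, embed_dim=1920,
+ # max_text_seq_length=226, 60x90 latents, 13 latent frames). Shapes are illustrative only.
+ #
+ #   embed = CogVideoXPatchEmbed(patch_size=2, in_channels=16, embed_dim=1920, text_embed_dim=4096)
+ #   text_embeds = torch.randn(1, 226, 4096)        # (batch, text_seq_len, text_embed_dim)
+ #   image_embeds = torch.randn(1, 13, 16, 60, 90)  # (batch, latent_frames, channels, height, width)
+ #   tokens = embed(text_embeds, image_embeds)      # (1, 226 + 13 * 30 * 45, 1920)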
177
+ @maybe_allow_in_graph
178
+ class CogVideoXBlock(nn.Module):
179
+ r"""
180
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
181
+
182
+ Parameters:
183
+ dim (`int`):
184
+ The number of channels in the input and output.
185
+ num_attention_heads (`int`):
186
+ The number of heads to use for multi-head attention.
187
+ attention_head_dim (`int`):
188
+ The number of channels in each head.
189
+ time_embed_dim (`int`):
190
+ The number of channels in timestep embedding.
191
+ dropout (`float`, defaults to `0.0`):
192
+ The dropout probability to use.
193
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
194
+ Activation function to be used in feed-forward.
195
+ attention_bias (`bool`, defaults to `False`):
196
+ Whether or not to use bias in attention projection layers.
197
+ qk_norm (`bool`, defaults to `True`):
198
+ Whether or not to use normalization after query and key projections in Attention.
199
+ norm_elementwise_affine (`bool`, defaults to `True`):
200
+ Whether to use learnable elementwise affine parameters for normalization.
201
+ norm_eps (`float`, defaults to `1e-5`):
202
+ Epsilon value for normalization layers.
203
+ final_dropout (`bool` defaults to `False`):
204
+ Whether to apply a final dropout after the last feed-forward layer.
205
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
206
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
207
+ ff_bias (`bool`, defaults to `True`):
208
+ Whether or not to use bias in Feed-forward layer.
209
+ attention_out_bias (`bool`, defaults to `True`):
210
+ Whether or not to use bias in Attention output projection layer.
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ dim: int,
216
+ num_attention_heads: int,
217
+ attention_head_dim: int,
218
+ time_embed_dim: int,
219
+ dropout: float = 0.0,
220
+ activation_fn: str = "gelu-approximate",
221
+ attention_bias: bool = False,
222
+ qk_norm: bool = True,
223
+ norm_elementwise_affine: bool = True,
224
+ norm_eps: float = 1e-5,
225
+ final_dropout: bool = True,
226
+ ff_inner_dim: Optional[int] = None,
227
+ ff_bias: bool = True,
228
+ attention_out_bias: bool = True,
229
+ ):
230
+ super().__init__()
231
+
232
+ # 1. Self Attention
233
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
234
+
235
+ self.attn1 = Attention(
236
+ query_dim=dim,
237
+ dim_head=attention_head_dim,
238
+ heads=num_attention_heads,
239
+ qk_norm="layer_norm" if qk_norm else None,
240
+ eps=1e-6,
241
+ bias=attention_bias,
242
+ out_bias=attention_out_bias,
243
+ processor=CogVideoXAttnProcessor2_0(),
244
+ )
245
+
246
+ # 2. Feed Forward
247
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
248
+
249
+ self.ff = FeedForward(
250
+ dim,
251
+ dropout=dropout,
252
+ activation_fn=activation_fn,
253
+ final_dropout=final_dropout,
254
+ inner_dim=ff_inner_dim,
255
+ bias=ff_bias,
256
+ )
257
+
258
+ def forward(
259
+ self,
260
+ hidden_states: torch.Tensor,
261
+ encoder_hidden_states: torch.Tensor,
262
+ temb: torch.Tensor,
263
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
264
+ ) -> torch.Tensor:
265
+ text_seq_length = encoder_hidden_states.size(1)
266
+
267
+ # norm & modulate
268
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
269
+ hidden_states, encoder_hidden_states, temb
270
+ )
271
+
272
+ # attention
273
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
274
+ hidden_states=norm_hidden_states,
275
+ encoder_hidden_states=norm_encoder_hidden_states,
276
+ image_rotary_emb=image_rotary_emb,
277
+ )
278
+
279
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
280
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
281
+
282
+ # norm & modulate
283
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
284
+ hidden_states, encoder_hidden_states, temb
285
+ )
286
+
287
+ # feed-forward
288
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
289
+ ff_output = self.ff(norm_hidden_states)
290
+
291
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
292
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
293
+
294
+ return hidden_states, encoder_hidden_states
295
+
296
+
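+ # Illustrative sketch (not part of the original file): a single CogVideoXBlock keeps the video and
+ # text tokens as two streams and returns both, with AdaLN-style gating driven by the timestep
+ # embedding `temb`. Sizes below use this file's defaults (dim = 30 * 64 = 1920, time_embed_dim = 512);
+ # rotary embeddings are omitted for brevity.
+ #
+ #   block = CogVideoXBlock(dim=1920, num_attention_heads=30, attention_head_dim=64, time_embed_dim=512)
+ #   vid = torch.randn(1, 17550, 1920)   # 13 * 30 * 45 video tokens
+ #   txt = torch.randn(1, 226, 1920)     # text tokens
+ #   temb = torch.randn(1, 512)          # timestep embedding
+ #   vid, txt = block(hidden_states=vid, encoder_hidden_states=txt, temb=temb)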
297
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
298
+ """
299
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
300
+
301
+ Parameters:
302
+ num_attention_heads (`int`, defaults to `30`):
303
+ The number of heads to use for multi-head attention.
304
+ attention_head_dim (`int`, defaults to `64`):
305
+ The number of channels in each head.
306
+ in_channels (`int`, defaults to `16`):
307
+ The number of channels in the input.
308
+ out_channels (`int`, *optional*, defaults to `16`):
309
+ The number of channels in the output.
310
+ flip_sin_to_cos (`bool`, defaults to `True`):
311
+ Whether to flip the sin to cos in the time embedding.
312
+ time_embed_dim (`int`, defaults to `512`):
313
+ Output dimension of timestep embeddings.
314
+ text_embed_dim (`int`, defaults to `4096`):
315
+ Input dimension of text embeddings from the text encoder.
316
+ num_layers (`int`, defaults to `30`):
317
+ The number of layers of Transformer blocks to use.
318
+ dropout (`float`, defaults to `0.0`):
319
+ The dropout probability to use.
320
+ attention_bias (`bool`, defaults to `True`):
321
+ Whether or not to use bias in the attention projection layers.
322
+ sample_width (`int`, defaults to `90`):
323
+ The width of the input latents.
324
+ sample_height (`int`, defaults to `60`):
325
+ The height of the input latents.
326
+ sample_frames (`int`, defaults to `49`):
327
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
328
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
329
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
330
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
331
+ patch_size (`int`, defaults to `2`):
332
+ The size of the patches to use in the patch embedding layer.
333
+ temporal_compression_ratio (`int`, defaults to `4`):
334
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
335
+ max_text_seq_length (`int`, defaults to `226`):
336
+ The maximum sequence length of the input text embeddings.
337
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
338
+ Activation function to use in feed-forward.
339
+ timestep_activation_fn (`str`, defaults to `"silu"`):
340
+ Activation function to use when generating the timestep embeddings.
341
+ norm_elementwise_affine (`bool`, defaults to `True`):
342
+ Whether or not to use elementwise affine in normalization layers.
343
+ norm_eps (`float`, defaults to `1e-5`):
344
+ The epsilon value to use in normalization layers.
345
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
346
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
347
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
348
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
349
+ """
350
+
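+ # Worked example (not part of the original file): with the defaults above, the number of latent frames
+ # the transformer sees is (sample_frames - 1) // temporal_compression_ratio + 1, i.e.
+ # (49 - 1) // 4 + 1 = 13, and the token width is num_attention_heads * attention_head_dim = 30 * 64 = 1920.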
351
+ _supports_gradient_checkpointing = True
352
+
353
+ @register_to_config
354
+ def __init__(
355
+ self,
356
+ num_attention_heads: int = 30,
357
+ attention_head_dim: int = 64,
358
+ in_channels: int = 16,
359
+ out_channels: Optional[int] = 16,
360
+ flip_sin_to_cos: bool = True,
361
+ freq_shift: int = 0,
362
+ time_embed_dim: int = 512,
363
+ text_embed_dim: int = 4096,
364
+ num_layers: int = 30,
365
+ dropout: float = 0.0,
366
+ attention_bias: bool = True,
367
+ sample_width: int = 90,
368
+ sample_height: int = 60,
369
+ sample_frames: int = 49,
370
+ patch_size: int = 2,
371
+ patch_size_t: Optional[int] = None,
372
+ temporal_compression_ratio: int = 4,
373
+ max_text_seq_length: int = 226,
374
+ activation_fn: str = "gelu-approximate",
375
+ timestep_activation_fn: str = "silu",
376
+ norm_elementwise_affine: bool = True,
377
+ norm_eps: float = 1e-5,
378
+ spatial_interpolation_scale: float = 1.875,
379
+ temporal_interpolation_scale: float = 1.0,
380
+ use_rotary_positional_embeddings: bool = False,
381
+ use_learned_positional_embeddings: bool = False,
382
+ patch_bias: bool = True,
383
+ add_noise_in_inpaint_model: bool = False,
384
+ ):
385
+ super().__init__()
386
+ inner_dim = num_attention_heads * attention_head_dim
387
+ self.patch_size_t = patch_size_t
388
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
389
+ raise ValueError(
390
+ "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned positional "
391
+ "embeddings enabled. If you're using a custom model and/or believe this should be supported, please "
392
+ "open an issue at https://github.com/huggingface/diffusers/issues."
393
+ )
394
+
395
+ # 1. Patch embedding
396
+ self.patch_embed = CogVideoXPatchEmbed(
397
+ patch_size=patch_size,
398
+ patch_size_t=patch_size_t,
399
+ in_channels=in_channels,
400
+ embed_dim=inner_dim,
401
+ text_embed_dim=text_embed_dim,
402
+ bias=patch_bias,
403
+ sample_width=sample_width,
404
+ sample_height=sample_height,
405
+ sample_frames=sample_frames,
406
+ temporal_compression_ratio=temporal_compression_ratio,
407
+ max_text_seq_length=max_text_seq_length,
408
+ spatial_interpolation_scale=spatial_interpolation_scale,
409
+ temporal_interpolation_scale=temporal_interpolation_scale,
410
+ use_positional_embeddings=not use_rotary_positional_embeddings,
411
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
412
+ )
413
+ self.embedding_dropout = nn.Dropout(dropout)
414
+
415
+ # 2. Time embeddings
416
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
417
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
418
+
419
+ # 3. Define spatio-temporal transformer blocks
420
+ self.transformer_blocks = nn.ModuleList(
421
+ [
422
+ CogVideoXBlock(
423
+ dim=inner_dim,
424
+ num_attention_heads=num_attention_heads,
425
+ attention_head_dim=attention_head_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ dropout=dropout,
428
+ activation_fn=activation_fn,
429
+ attention_bias=attention_bias,
430
+ norm_elementwise_affine=norm_elementwise_affine,
431
+ norm_eps=norm_eps,
432
+ )
433
+ for _ in range(num_layers)
434
+ ]
435
+ )
436
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
437
+
438
+ # 4. Output blocks
439
+ self.norm_out = AdaLayerNorm(
440
+ embedding_dim=time_embed_dim,
441
+ output_dim=2 * inner_dim,
442
+ norm_elementwise_affine=norm_elementwise_affine,
443
+ norm_eps=norm_eps,
444
+ chunk_dim=1,
445
+ )
446
+
447
+ if patch_size_t is None:
448
+ # For CogVideoX 1.0
449
+ output_dim = patch_size * patch_size * out_channels
450
+ else:
451
+ # For CogVideoX 1.5
452
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
453
+
454
+ self.proj_out = nn.Linear(inner_dim, output_dim)
455
+
456
+ self.gradient_checkpointing = False
457
+ self.sp_world_size = 1
458
+ self.sp_world_rank = 0
459
+
460
+ def _set_gradient_checkpointing(self, module, value=False):
461
+ self.gradient_checkpointing = value
462
+
463
+ def enable_multi_gpus_inference(self):
464
+ self.sp_world_size = get_sequence_parallel_world_size()
465
+ self.sp_world_rank = get_sequence_parallel_rank()
466
+ self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
467
+
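+ # Illustrative sketch (not part of the original file): this is expected to be called once per rank
+ # after the sequence-parallel process group has been initialised by the project's dist_utils helpers;
+ # forward() then chunks the video tokens across ranks and all-gathers the result, e.g.:
+ #
+ #   transformer = CogVideoXTransformer3DModel.from_pretrained("path/to/model", subfolder="transformer")
+ #   transformer.enable_multi_gpus_inference()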
468
+ @property
469
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
470
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
471
+ r"""
472
+ Returns:
473
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
474
+ indexed by their weight names.
475
+ """
476
+ # set recursively
477
+ processors = {}
478
+
479
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
480
+ if hasattr(module, "get_processor"):
481
+ processors[f"{name}.processor"] = module.get_processor()
482
+
483
+ for sub_name, child in module.named_children():
484
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
485
+
486
+ return processors
487
+
488
+ for name, module in self.named_children():
489
+ fn_recursive_add_processors(name, module, processors)
490
+
491
+ return processors
492
+
493
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
494
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
495
+ r"""
496
+ Sets the attention processor to use to compute attention.
497
+
498
+ Parameters:
499
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
500
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
501
+ for **all** `Attention` layers.
502
+
503
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
504
+ processor. This is strongly recommended when setting trainable attention processors.
505
+
506
+ """
507
+ count = len(self.attn_processors.keys())
508
+
509
+ if isinstance(processor, dict) and len(processor) != count:
510
+ raise ValueError(
511
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
512
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
513
+ )
514
+
515
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
516
+ if hasattr(module, "set_processor"):
517
+ if not isinstance(processor, dict):
518
+ module.set_processor(processor)
519
+ else:
520
+ module.set_processor(processor.pop(f"{name}.processor"))
521
+
522
+ for sub_name, child in module.named_children():
523
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
524
+
525
+ for name, module in self.named_children():
526
+ fn_recursive_attn_processor(name, module, processor)
527
+
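+ # Illustrative sketch (not part of the original file): a single processor instance is applied to every
+ # Attention layer, or a dict keyed by the names returned from `attn_processors` sets them per layer:
+ #
+ #   model.set_attn_processor(CogVideoXAttnProcessor2_0())
+ #   model.set_attn_processor({name: CogVideoXAttnProcessor2_0() for name in model.attn_processors})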
528
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
529
+ def fuse_qkv_projections(self):
530
+ """
531
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
532
+ are fused. For cross-attention modules, key and value projection matrices are fused.
533
+
534
+ <Tip warning={true}>
535
+
536
+ This API is 🧪 experimental.
537
+
538
+ </Tip>
539
+ """
540
+ self.original_attn_processors = None
541
+
542
+ for _, attn_processor in self.attn_processors.items():
543
+ if "Added" in str(attn_processor.__class__.__name__):
544
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
545
+
546
+ self.original_attn_processors = self.attn_processors
547
+
548
+ for module in self.modules():
549
+ if isinstance(module, Attention):
550
+ module.fuse_projections(fuse=True)
551
+
552
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
553
+
554
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
555
+ def unfuse_qkv_projections(self):
556
+ """Disables the fused QKV projection if enabled.
557
+
558
+ <Tip warning={true}>
559
+
560
+ This API is 🧪 experimental.
561
+
562
+ </Tip>
563
+
564
+ """
565
+ if self.original_attn_processors is not None:
566
+ self.set_attn_processor(self.original_attn_processors)
567
+
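+ # Illustrative sketch (not part of the original file): fusing the QKV projections is an experimental,
+ # reversible optimization:
+ #
+ #   model.fuse_qkv_projections()     # swaps in FusedCogVideoXAttnProcessor2_0
+ #   # ... run inference ...
+ #   model.unfuse_qkv_projections()   # restores the original processors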
568
+ def forward(
569
+ self,
570
+ hidden_states: torch.Tensor,
571
+ encoder_hidden_states: torch.Tensor,
572
+ timestep: Union[int, float, torch.LongTensor],
573
+ timestep_cond: Optional[torch.Tensor] = None,
574
+ inpaint_latents: Optional[torch.Tensor] = None,
575
+ control_latents: Optional[torch.Tensor] = None,
576
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
577
+ return_dict: bool = True,
578
+ ):
579
+ batch_size, num_frames, channels, height, width = hidden_states.shape
580
+ if num_frames == 1 and self.patch_size_t is not None:
581
+ hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
582
+ if inpaint_latents is not None:
583
+ inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
584
+ if control_latents is not None:
585
+ control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
586
+ local_num_frames = num_frames + 1
587
+ else:
588
+ local_num_frames = num_frames
589
+
590
+ # 1. Time embedding
591
+ timesteps = timestep
592
+ t_emb = self.time_proj(timesteps)
593
+
594
+ # timesteps does not contain any weights and will always return f32 tensors
595
+ # but time_embedding might actually be running in fp16. so we need to cast here.
596
+ # there might be better ways to encapsulate this.
597
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
598
+ emb = self.time_embedding(t_emb, timestep_cond)
599
+
600
+ # 2. Patch embedding
601
+ if inpaint_latents is not None:
602
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
603
+ if control_latents is not None:
604
+ hidden_states = torch.concat([hidden_states, control_latents], 2)
605
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
606
+ hidden_states = self.embedding_dropout(hidden_states)
607
+
608
+ text_seq_length = encoder_hidden_states.shape[1]
609
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
610
+ hidden_states = hidden_states[:, text_seq_length:]
611
+
612
+ # Context Parallel
613
+ if self.sp_world_size > 1:
614
+ hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
615
+ if image_rotary_emb is not None:
616
+ image_rotary_emb = (
617
+ torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
618
+ torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
619
+ )
620
+
621
+ # 3. Transformer blocks
622
+ for i, block in enumerate(self.transformer_blocks):
623
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
624
+
625
+ def create_custom_forward(module):
626
+ def custom_forward(*inputs):
627
+ return module(*inputs)
628
+
629
+ return custom_forward
630
+
631
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
632
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
633
+ create_custom_forward(block),
634
+ hidden_states,
635
+ encoder_hidden_states,
636
+ emb,
637
+ image_rotary_emb,
638
+ **ckpt_kwargs,
639
+ )
640
+ else:
641
+ hidden_states, encoder_hidden_states = block(
642
+ hidden_states=hidden_states,
643
+ encoder_hidden_states=encoder_hidden_states,
644
+ temb=emb,
645
+ image_rotary_emb=image_rotary_emb,
646
+ )
647
+
648
+ if not self.config.use_rotary_positional_embeddings:
649
+ # CogVideoX-2B
650
+ hidden_states = self.norm_final(hidden_states)
651
+ else:
652
+ # CogVideoX-5B
653
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
654
+ hidden_states = self.norm_final(hidden_states)
655
+ hidden_states = hidden_states[:, text_seq_length:]
656
+
657
+ # 4. Final block
658
+ hidden_states = self.norm_out(hidden_states, temb=emb)
659
+ hidden_states = self.proj_out(hidden_states)
660
+
661
+ if self.sp_world_size > 1:
662
+ hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
663
+
664
+ # 5. Unpatchify
665
+ p = self.config.patch_size
666
+ p_t = self.config.patch_size_t
667
+
668
+ if p_t is None:
669
+ output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
670
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
671
+ else:
672
+ output = hidden_states.reshape(
673
+ batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
674
+ )
675
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
676
+
677
+ if num_frames == 1:
678
+ output = output[:, :num_frames, :]
679
+
680
+ if not return_dict:
681
+ return (output,)
682
+ return Transformer2DModelOutput(sample=output)
683
+
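+ # Illustrative sketch (not part of the original file): a single denoising call with this file's default
+ # configuration; image_rotary_emb is omitted because the defaults use sinusoidal positional embeddings.
+ #
+ #   latents = torch.randn(1, 13, 16, 60, 90)    # (batch, latent_frames, channels, height, width)
+ #   prompt_embeds = torch.randn(1, 226, 4096)   # text-encoder embeddings
+ #   out = transformer(latents, prompt_embeds, timestep=torch.tensor([999])).sample
+ #   # `out` keeps the latent frame/height/width layout with `out_channels` channels.
+ #
+ # When sp_world_size > 1 (see enable_multi_gpus_inference), each rank runs the blocks on a
+ # 1 / sp_world_size chunk of the video tokens and the outputs are all-gathered before unpatchifying.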
684
+ @classmethod
685
+ def from_pretrained(
686
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
687
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16, use_vae_mask=False, stack_mask=False,
688
+ ):
689
+ if subfolder is not None:
690
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
691
+ print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")
692
+
693
+ config_file = os.path.join(pretrained_model_path, 'config.json')
694
+ if not os.path.isfile(config_file):
695
+ raise RuntimeError(f"{config_file} does not exist")
696
+ with open(config_file, "r") as f:
697
+ config = json.load(f)
698
+
699
+ if use_vae_mask:
700
+ print('[DEBUG] use vae to encode mask')
701
+ config['in_channels'] = 48
702
+ elif stack_mask:
703
+ print('[DEBUG] use stacking mask')
704
+ config['in_channels'] = 36
705
+
706
+ from diffusers.utils import WEIGHTS_NAME
707
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
708
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
709
+
710
+ if "dict_mapping" in transformer_additional_kwargs.keys():
711
+ for key in transformer_additional_kwargs["dict_mapping"]:
712
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
713
+
714
+ if low_cpu_mem_usage:
715
+ try:
716
+ import re
717
+
718
+ from diffusers.models.modeling_utils import \
719
+ load_model_dict_into_meta
720
+ from diffusers.utils import is_accelerate_available
721
+ if is_accelerate_available():
722
+ import accelerate
723
+
724
+ # Instantiate model with empty weights
725
+ with accelerate.init_empty_weights():
726
+ model = cls.from_config(config, **transformer_additional_kwargs)
727
+
728
+ param_device = "cpu"
729
+ if os.path.exists(model_file):
730
+ state_dict = torch.load(model_file, map_location="cpu")
731
+ elif os.path.exists(model_file_safetensors):
732
+ from safetensors.torch import load_file, safe_open
733
+ state_dict = load_file(model_file_safetensors)
734
+ else:
735
+ from safetensors.torch import load_file, safe_open
736
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
737
+ state_dict = {}
738
+ for _model_file_safetensors in model_files_safetensors:
739
+ _state_dict = load_file(_model_file_safetensors)
740
+ for key in _state_dict:
741
+ state_dict[key] = _state_dict[key]
742
+ model._convert_deprecated_attention_blocks(state_dict)
743
+ # move the params from meta device to cpu
744
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
745
+ if len(missing_keys) > 0:
746
+ raise ValueError(
747
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
748
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
749
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
750
+ " those weights or else make sure your checkpoint file is correct."
751
+ )
752
+
753
+ unexpected_keys = load_model_dict_into_meta(
754
+ model,
755
+ state_dict,
756
+ device=param_device,
757
+ dtype=torch_dtype,
758
+ model_name_or_path=pretrained_model_path,
759
+ )
760
+
761
+ if cls._keys_to_ignore_on_load_unexpected is not None:
762
+ for pat in cls._keys_to_ignore_on_load_unexpected:
763
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
764
+
765
+ if len(unexpected_keys) > 0:
766
+ print(
767
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
768
+ )
769
+ return model
770
+ except Exception as e:
771
+ print(
772
+ f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
773
+ )
774
+
775
+ model = cls.from_config(config, **transformer_additional_kwargs)
776
+ if os.path.exists(model_file):
777
+ state_dict = torch.load(model_file, map_location="cpu")
778
+ elif os.path.exists(model_file_safetensors):
779
+ from safetensors.torch import load_file, safe_open
780
+ state_dict = load_file(model_file_safetensors)
781
+ else:
782
+ from safetensors.torch import load_file, safe_open
783
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
784
+ state_dict = {}
785
+ for _model_file_safetensors in model_files_safetensors:
786
+ _state_dict = load_file(_model_file_safetensors)
787
+ for key in _state_dict:
788
+ state_dict[key] = _state_dict[key]
789
+
790
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
791
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
792
+ if len(new_shape) == 5:
793
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
794
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
795
+ elif len(new_shape) == 2:
796
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
797
+ if use_vae_mask:
798
+ print('[DEBUG] patch_embed.proj.weight size does not match due to vae-encoded mask')
799
+ latent_ch = 16
800
+ feat_scale = 8
801
+ feat_dim = int(latent_ch * feat_scale)
802
+ old_total_dim = state_dict['patch_embed.proj.weight'].size(1)
803
+ new_total_dim = model.state_dict()['patch_embed.proj.weight'].size(1)
804
+ model.state_dict()['patch_embed.proj.weight'][:, :feat_dim] = state_dict['patch_embed.proj.weight'][:, :feat_dim]
805
+ model.state_dict()['patch_embed.proj.weight'][:, -feat_dim:] = state_dict['patch_embed.proj.weight'][:, -feat_dim:]
806
+ for i in range(feat_dim, new_total_dim - feat_dim, feat_scale):
807
+ model.state_dict()['patch_embed.proj.weight'][:, i:i+feat_scale] = state_dict['patch_embed.proj.weight'][:, feat_dim:-feat_dim]
808
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
809
+ else:
810
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
811
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
812
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
813
+ else:
814
+ model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
815
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
816
+ else:
817
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
818
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
819
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
820
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
821
+ else:
822
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
823
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
824
+
825
+ tmp_state_dict = {}
826
+ for key in state_dict:
827
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
828
+ tmp_state_dict[key] = state_dict[key]
829
+ else:
830
+ print(key, "size does not match, skipping")
831
+
832
+ state_dict = tmp_state_dict
833
+
834
+ m, u = model.load_state_dict(state_dict, strict=False)
835
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
836
+ print(m)
837
+
838
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
839
+ print(f"### All Parameters: {sum(params) / 1e6} M")
840
+
841
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
842
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
843
+
844
+ model = model.to(torch_dtype)
845
+ return model
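+
+ # Minimal smoke-test sketch (not part of the original file): builds a tiny, randomly initialised model
+ # and runs one forward pass. It assumes the local dist_utils module imported at the top of this file is
+ # importable; the sizes are deliberately small and unrelated to any released checkpoint.
+ #
+ #   tiny = CogVideoXTransformer3DModel(num_attention_heads=2, attention_head_dim=16, num_layers=1,
+ #                                      time_embed_dim=32, text_embed_dim=32, sample_width=16,
+ #                                      sample_height=16, sample_frames=5, max_text_seq_length=8)
+ #   video = torch.randn(1, 2, 16, 16, 16)   # (batch, latent_frames, channels, height, width)
+ #   text = torch.randn(1, 8, 32)            # (batch, text_seq_len, text_embed_dim)
+ #   out = tiny(video, text, timestep=torch.tensor([1])).sample   # -> torch.Size([1, 2, 16, 16, 16])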