smartdigitalnetworks commited on
Commit
78e22d2
·
verified ·
1 Parent(s): 1bf5c79

Update packages/ltx-core/src/ltx_core/model/transformer/attention.py

Browse files
packages/ltx-core/src/ltx_core/model/transformer/attention.py CHANGED
@@ -12,7 +12,10 @@ try:
12
  except ImportError:
13
  memory_efficient_attention = None
14
 
15
- import flash_attn_interface
 
 
 
16
 
17
  class AttentionCallable(Protocol):
18
  def __call__(
@@ -120,7 +123,7 @@ class AttentionFunction(Enum):
120
  def __call__(
121
  self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
122
  ) -> torch.Tensor:
123
- if mask is None:
124
  return FlashAttention3()(q, k, v, heads, mask)
125
  else:
126
  return (
 
12
  except ImportError:
13
  memory_efficient_attention = None
14
 
15
+ try:
16
+ import flash_attn_interface
17
+ except ImportError:
18
+ flash_attn_interface = None
19
 
20
  class AttentionCallable(Protocol):
21
  def __call__(
 
123
  def __call__(
124
  self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
125
  ) -> torch.Tensor:
126
+ if mask is None and flash_attn_interface is not None:
127
  return FlashAttention3()(q, k, v, heads, mask)
128
  else:
129
  return (