Spaces:
Sleeping
Sleeping
Update packages/ltx-core/src/ltx_core/model/transformer/attention.py
Browse files
packages/ltx-core/src/ltx_core/model/transformer/attention.py
CHANGED
|
@@ -12,7 +12,10 @@ try:
|
|
| 12 |
except ImportError:
|
| 13 |
memory_efficient_attention = None
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class AttentionCallable(Protocol):
|
| 18 |
def __call__(
|
|
@@ -120,7 +123,7 @@ class AttentionFunction(Enum):
|
|
| 120 |
def __call__(
|
| 121 |
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 122 |
) -> torch.Tensor:
|
| 123 |
-
if mask is None:
|
| 124 |
return FlashAttention3()(q, k, v, heads, mask)
|
| 125 |
else:
|
| 126 |
return (
|
|
|
|
| 12 |
except ImportError:
|
| 13 |
memory_efficient_attention = None
|
| 14 |
|
| 15 |
+
try:
|
| 16 |
+
import flash_attn_interface
|
| 17 |
+
except ImportError:
|
| 18 |
+
flash_attn_interface = None
|
| 19 |
|
| 20 |
class AttentionCallable(Protocol):
|
| 21 |
def __call__(
|
|
|
|
| 123 |
def __call__(
|
| 124 |
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
|
| 125 |
) -> torch.Tensor:
|
| 126 |
+
if mask is None and flash_attn_interface is not None:
|
| 127 |
return FlashAttention3()(q, k, v, heads, mask)
|
| 128 |
else:
|
| 129 |
return (
|