kashif HF Staff commited on
Commit
f4a945b
·
verified ·
1 Parent(s): ce6a36f

fix: remove deprecated is_torch_fx_available for transformers v5 compat

Browse files
Files changed (1) hide show
  1. modeling_llada2_moe.py +0 -8
modeling_llada2_moe.py CHANGED
@@ -52,7 +52,6 @@ from transformers.utils import (
52
  logging,
53
  replace_return_docstrings,
54
  )
55
- from transformers.utils.import_utils import is_torch_fx_available
56
  from .configuration_llada2_moe import LLaDA2MoeConfig
57
  from transformers.generation.utils import GenerationMixin
58
 
@@ -62,13 +61,6 @@ if is_flash_attn_2_available():
62
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
63
 
64
 
65
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
66
- # It means that the function will not be traced through and simply appear as a node in the graph.
67
- if is_torch_fx_available():
68
- if not is_torch_greater_or_equal_than_1_13:
69
- import torch.fx
70
-
71
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
72
 
73
 
74
  logger = logging.get_logger(__name__)
 
52
  logging,
53
  replace_return_docstrings,
54
  )
 
55
  from .configuration_llada2_moe import LLaDA2MoeConfig
56
  from transformers.generation.utils import GenerationMixin
57
 
 
61
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
62
 
63
 
 
 
 
 
 
 
 
64
 
65
 
66
  logger = logging.get_logger(__name__)