[PyTorch][Flash Attn] Add fallback import for FA3 #2806
eattia-nvidia wants to merge 3 commits into NVIDIA:main
Conversation
…erface.py is outside flash_attn_3 package

Some FA3 installations (e.g. via pip) place flash_attn_interface.py directly under site-packages/ rather than inside flash_attn_3/. This causes a ModuleNotFoundError when importing from flash_attn_3.flash_attn_interface. Add a try/except ModuleNotFoundError fallback to import directly from flash_attn_interface when the subpackage import fails.

Signed-off-by: Emmanuel Attia <eattia@nvidia.com>
Greptile Summary: This PR adds a robust fallback import path for FA3 to support wheel-based installations that place flash_attn_interface.py directly under site-packages/ rather than inside the flash_attn_3/ package.
Confidence Score: 4/5

This PR is safe to merge; it handles both FA3 install layouts correctly, adds a visible warning on the fallback path, and raises a helpful error when neither layout is found. Both concerns from the previous review round are fully addressed: a warnings.warn is now emitted on the fallback path, and find_spec replaces try/except ModuleNotFoundError to avoid swallowing transitive import errors. The change is narrowly scoped to the import block, does not affect runtime behaviour when FA3 is absent or fully installed under flash_attn_3/, and the new explicit raise ModuleNotFoundError gives clearer diagnostics than the original implicit ImportError. No new logic issues were found. No files require special attention.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Module load] --> B{importlib.metadata:\nflash-attn-3 installed?}
    B -- PackageNotFoundError --> C[Set flash_attn_func_v3 = None\nflash_attn_varlen_func_v3 = None\nflash_attn_with_kvcache_v3 = None]
    B -- Found --> D{find_spec\nflash_attn_3.flash_attn_interface\nnot None?}
    D -- Yes --> E[Import from\nflash_attn_3.flash_attn_interface]
    D -- No --> F{find_spec\nflash_attn_interface\nnot None?}
    F -- Yes --> G[warnings.warn\nImport from\nflash_attn_interface]
    F -- No --> H[raise ModuleNotFoundError\nwith helpful message]
    E --> I[fa_utils.set_flash_attention_3_params]
    G --> I
```
Reviews (2): Last reviewed commit: "[PyTorch][Flash Attn] Use find_spec for ..."
```python
except ModuleNotFoundError:
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
    from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
```
No diagnostic logging on fallback path
When flash_attn_3.flash_attn_interface is not found and the fallback to the flat flash_attn_interface import succeeds silently, there is no indication in the logs that the alternative layout was used. This makes it harder to diagnose import-related issues at runtime (e.g., someone upgrading FA3 and expecting the flash_attn_3 layout to work).
A warning log similar to the one used for the FA2 version mismatch would be helpful:
```python
except ModuleNotFoundError:
    import warnings

    warnings.warn(
        "Could not import from flash_attn_3.flash_attn_interface; "
        "falling back to top-level flash_attn_interface. "
        "Consider installing FA3 via a package that places flash_attn_interface "
        "inside the flash_attn_3 namespace.",
        ImportWarning,
        stacklevel=2,
    )
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    ...
```
```python
try:
    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_3.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flash_attn_3.flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
except ModuleNotFoundError:
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
    from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
```
ModuleNotFoundError from transitive imports silently swallowed
ModuleNotFoundError is a subclass of ImportError and can be raised not only when flash_attn_3.flash_attn_interface itself is missing, but also when any transitive dependency of that module is missing at import time (e.g., if flash_attn_3.flash_attn_interface internally does from flash_attn_3 import _C and _C is missing).
In that scenario the except ModuleNotFoundError would trigger the flat-layout fallback, even though flash_attn_3.flash_attn_interface actually exists. The fallback might then succeed with a stale or different build, masking a real installation problem.
A narrower guard would avoid this:
```python
try:
    from flash_attn_3 import flash_attn_interface as _fa3_iface
except ModuleNotFoundError:
    from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
    from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_v3
    from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
else:
    flash_attn_func_v3 = _fa3_iface.flash_attn_func
    flash_attn_varlen_func_v3 = _fa3_iface.flash_attn_varlen_func
    flash_attn_with_kvcache_v3 = _fa3_iface.flash_attn_with_kvcache
    _flash_attn_fwd_v3 = _fa3_iface._flash_attn_forward
    _flash_attn_bwd_v3 = _fa3_iface._flash_attn_backward
```

Alternatively, checking importlib.util.find_spec("flash_attn_3.flash_attn_interface") before importing would make the intent explicit without relying on exception-driven control flow.
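A further way to keep the guard narrow, complementary to find_spec, is to inspect ModuleNotFoundError's `name` attribute, which records which module actually failed to resolve. The helper name `safe_import` below is made up for illustration:

```python
import importlib


def safe_import(module_name):
    """Import `module_name`, returning None only when the module itself (or a
    parent package in its dotted path) is missing; a ModuleNotFoundError raised
    by one of the module's own internal imports is re-raised, not masked."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        # exc.name is the module that could not be found
        if exc.name == module_name or module_name.startswith(exc.name + "."):
            return None
        raise
```

With this helper, `safe_import("flash_attn_3.flash_attn_interface")` returns None for a genuinely absent package, but a broken transitive dependency (such as a missing `_C` extension) still surfaces as an error instead of silently triggering the fallback.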
… warning

Address review feedback:
- Use importlib.util.find_spec() for explicit module checking instead of exception-driven control flow, avoiding masking real import errors
- Add a warning when the flat layout fallback is used, for easier debugging
- Raise a clear error if flash_attn_interface is not found in either location

Signed-off-by: Emmanuel Attia <eattia@nvidia.com>
Add fallback import for FA3 when flash_attn_interface.py is outside flash_attn_3 package
Description
Some FA3 installations (e.g. via a wheel built from the https://github.com/Dao-AILab/flash-attention repo) place flash_attn_interface.py directly under site-packages/ rather than inside the flash_attn_3/ folder:
```
site-packages/
├── flash_attn_interface.py
└── flash_attn_3/
    ├── __init__.py
    └── _C.abi3.so
```
This causes a ModuleNotFoundError when TE tries to import from flash_attn_3.flash_attn_interface, even though FA3 is
correctly installed and detected via importlib.metadata.
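The metadata check is independent of the import layout, which is why detection succeeds while the import fails. A minimal sketch, assuming the distribution is named "flash-attn-3" (the helper name is illustrative):

```python
import importlib.metadata


def fa3_installed_version(dist_name: str = "flash-attn-3"):
    """Return the installed FA3 version string, or None if the distribution
    is absent. This reads package metadata rather than importing any module,
    so it works regardless of where flash_attn_interface.py lives."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None
```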
This PR adds a try/except ModuleNotFoundError fallback so both layouts are supported:
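In condensed form, the pattern looks like this (the real diff imports five symbols; only flash_attn_func is shown, and an inner fallback to None is added here so the snippet also runs where FA3 is not installed at all):

```python
try:
    # Packaged layout: flash_attn_interface lives inside flash_attn_3/
    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
except ModuleNotFoundError:
    try:
        # Flat wheel layout: flash_attn_interface.py sits directly in site-packages/
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
    except ModuleNotFoundError:
        flash_attn_func_v3 = None  # FA3 not installed in this environment
```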
Type of change
Changes
Please list the changes introduced in this PR:
- Fall back to `from flash_attn_interface import ...` when `from flash_attn_3.flash_attn_interface import ...` fails
Checklist: