
[PyTorch][Flash Attn] Add fallback import for FA3 #2806

Open
eattia-nvidia wants to merge 3 commits into NVIDIA:main from
eattia-nvidia:fix_fa3_import_fallback

Conversation

@eattia-nvidia

Add fallback import for FA3 when flash_attn_interface.py is outside flash_attn_3 package

Description

Some FA3 installations (e.g. a wheel built from the https://github.com/Dao-AILab/flash-attention repo) place
flash_attn_interface.py directly under site-packages/ rather than inside the flash_attn_3/ folder:

site-packages/
├── flash_attn_interface.py
└── flash_attn_3/
    ├── __init__.py
    └── _C.abi3.so

This causes a ModuleNotFoundError when TE tries to import from flash_attn_3.flash_attn_interface, even though FA3 is
correctly installed and detected via importlib.metadata.
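The mismatch arises because distribution metadata and importable module paths are resolved independently. The helper below (the function name and this check are ours, purely illustrative, not TE's actual code) reports exactly that state: metadata says "installed" while the given module path is not importable.

```python
import importlib.util
from importlib import metadata

def installed_but_not_importable(dist_name: str, module_name: str) -> bool:
    """True when a wheel's metadata is present but the module path fails to resolve."""
    try:
        metadata.version(dist_name)          # wheel metadata present -> "detected"
        detected = True
    except metadata.PackageNotFoundError:
        detected = False
    try:
        importable = importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:              # parent package missing entirely
        importable = False
    return detected and not importable
```

For FA3, calling this with the distribution name and the packaged module path (e.g. "flash-attn-3" and "flash_attn_3.flash_attn_interface") would return True precisely in the broken layout this PR addresses.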

This PR adds a try/except ModuleNotFoundError fallback so both layouts are supported:

  1. flash_attn_interface.py inside flash_attn_3/ (existing behavior)
  2. flash_attn_interface.py directly in site-packages/
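The fallback pattern can be sketched as a small generic helper (the helper name is ours and this is not the verbatim diff; it is written against module names so it runs without FA3 installed):

```python
import importlib

def import_with_fallback(packaged: str, flat: str):
    """Try the packaged module path first, then the flat site-packages layout."""
    try:
        return importlib.import_module(packaged)   # layout 1: inside the package
    except ModuleNotFoundError:
        return importlib.import_module(flat)       # layout 2: directly in site-packages/

# With FA3 this would be:
# iface = import_with_fallback("flash_attn_3.flash_attn_interface",
#                              "flash_attn_interface")
```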

Type of change

  • Documentation change (documentation-only change, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add try/except ModuleNotFoundError around FA3 imports in backends.py to fall back to from flash_attn_interface
    import ... when from flash_attn_3.flash_attn_interface import ... fails

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Add fallback import for FA3 when flash_attn_interface.py is outside flash_attn_3 package

Some FA3 installations (e.g. via pip) place flash_attn_interface.py
directly under site-packages/ rather than inside flash_attn_3/.
This causes a ModuleNotFoundError when importing from
flash_attn_3.flash_attn_interface.

Add a try/except ModuleNotFoundError fallback to import directly
from flash_attn_interface when the subpackage import fails.

Signed-off-by: Emmanuel Attia <eattia@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 26, 2026

Greptile Summary

This PR adds a robust fallback import path for FA3 to support wheel-based installations that place flash_attn_interface.py directly under site-packages/ rather than inside the flash_attn_3/ sub-package. It replaces the previous unconditional from flash_attn_3.flash_attn_interface import … with a two-stage importlib.util.find_spec check, and addresses both concerns raised in the previous review round (silent fallback and transitive-import masking).

Key changes in backends.py:

  • Adds import importlib.util to the module header.
  • Uses importlib.util.find_spec("flash_attn_3.flash_attn_interface") as the primary check; if the submodule exists it is imported as before.
  • If the submodule is absent, falls back to find_spec("flash_attn_interface") (the flat layout) and imports from there, emitting a warnings.warn so operators can see which path was taken.
  • If neither spec is found, raises an informative ModuleNotFoundError with a message explaining both expected locations.
  • Both previous review concerns are resolved: diagnostic logging is now present on the fallback path, and using find_spec instead of try/except ModuleNotFoundError prevents transitive import errors inside flash_attn_3.flash_attn_interface from being silently swallowed.
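The two-stage check the summary describes can be sketched as follows (the function name, message text, and the guard around the parent package are ours; the real change lives in backends.py):

```python
import importlib.util
import warnings

def pick_fa3_module(packaged: str = "flash_attn_3.flash_attn_interface",
                    flat: str = "flash_attn_interface") -> str:
    """Return the importable module name for the FA3 interface, or raise."""
    try:
        spec = importlib.util.find_spec(packaged)
    except ModuleNotFoundError:          # parent package itself is absent
        spec = None
    if spec is not None:
        return packaged                  # packaged layout (existing behavior)
    if importlib.util.find_spec(flat) is not None:
        warnings.warn(
            f"{packaged} not found; falling back to top-level {flat}."
        )
        return flat                      # flat wheel layout
    raise ModuleNotFoundError(
        f"flash_attn_interface not found as {packaged} or as top-level {flat}"
    )
```

Because find_spec only locates the module without executing it, a transitive import failure inside the packaged interface surfaces later as a real error instead of silently triggering the fallback.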

Confidence Score: 4/5

This PR is safe to merge; it handles both FA3 install layouts correctly, adds a visible warning on the fallback path, and raises a helpful error when neither layout is found.

Both concerns from the previous review round are fully addressed: a warnings.warn is now emitted on the fallback path, and find_spec replaces try/except ModuleNotFoundError to avoid swallowing transitive import errors. The change is narrowly scoped to the import block, does not affect runtime behaviour when FA3 is absent or fully installed under flash_attn_3/, and the new explicit raise ModuleNotFoundError gives clearer diagnostics than the original implicit ImportError. No new logic issues were found.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/backends.py Replaces direct FA3 imports with find_spec-based detection plus a fallback to the flat flash_attn_interface layout; adds a warning and a helpful raise on failure.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Module load] --> B{importlib.metadata:\nflash-attn-3 installed?}
    B -- PackageNotFoundError --> C[Set flash_attn_func_v3 = None\nflash_attn_varlen_func_v3 = None\nflash_attn_with_kvcache_v3 = None]
    B -- Found --> D{find_spec\nflash_attn_3.flash_attn_interface\nnot None?}
    D -- Yes --> E[Import from\nflash_attn_3.flash_attn_interface]
    D -- No --> F{find_spec\nflash_attn_interface\nnot None?}
    F -- Yes --> G[warnings.warn\nImport from\nflash_attn_interface]
    F -- No --> H[raise ModuleNotFoundError\nwith helpful message]
    E --> I[fa_utils.set_flash_attention_3_params]
    G --> I

Reviews (2) — last reviewed commit: "[PyTorch][Flash Attn] Use find_spec for ..."

Comment on lines +151 to +160
    except ModuleNotFoundError:
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flash_attn_interface import (
            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
        )
        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3

P2 No diagnostic logging on fallback path

When flash_attn_3.flash_attn_interface is not found and the fallback to the flat flash_attn_interface import succeeds silently, there is no indication in the logs that the alternative layout was used. This makes it harder to diagnose import-related issues at runtime (e.g., someone upgrading FA3 and expecting the flash_attn_3 layout to work).

A warning log similar to the one used for the FA2 version mismatch would be helpful:

    except ModuleNotFoundError:
        import warnings
        warnings.warn(
            "Could not import from flash_attn_3.flash_attn_interface; "
            "falling back to top-level flash_attn_interface. "
            "Consider installing FA3 via a package that places flash_attn_interface "
            "inside the flash_attn_3 namespace.",
            ImportWarning,
            stacklevel=2,
        )
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        ...


Comment on lines +141 to +160
    try:
        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_3.flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flash_attn_3.flash_attn_interface import (
            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
        )
        from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
    except ModuleNotFoundError:
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_varlen_func_v3,
        )
        from flash_attn_interface import (
            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
        )
        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3

P2 ModuleNotFoundError from transitive imports silently swallowed

ModuleNotFoundError is a subclass of ImportError and can be raised not only when flash_attn_3.flash_attn_interface itself is missing, but also when any transitive dependency of that module is missing at import time (e.g., if flash_attn_3.flash_attn_interface internally does from flash_attn_3 import _C and _C is missing).

In that scenario the except ModuleNotFoundError would trigger the flat-layout fallback, even though flash_attn_3.flash_attn_interface actually exists. The fallback might then succeed with a stale or different build, masking a real installation problem.

A narrower guard would avoid this:

    try:
        from flash_attn_3 import flash_attn_interface as _fa3_iface
    except ModuleNotFoundError:
        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
        from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_v3
        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
    else:
        flash_attn_func_v3 = _fa3_iface.flash_attn_func
        flash_attn_varlen_func_v3 = _fa3_iface.flash_attn_varlen_func
        flash_attn_with_kvcache_v3 = _fa3_iface.flash_attn_with_kvcache
        _flash_attn_fwd_v3 = _fa3_iface._flash_attn_forward
        _flash_attn_bwd_v3 = _fa3_iface._flash_attn_backward

Alternatively, checking importlib.util.find_spec("flash_attn_3.flash_attn_interface") before importing would make the intent explicit without relying on exception-driven control flow.
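The pitfall is easy to demonstrate in isolation: the throwaway module below exists on disk, yet importing it raises ModuleNotFoundError because of its own missing dependency, so except ModuleNotFoundError cannot tell the two failures apart, while find_spec can (the module and dependency names here are invented for the demo):

```python
import importlib.util
import pathlib
import sys
import tempfile

# Create a module that EXISTS but whose own import fails at load time.
tmp = pathlib.Path(tempfile.mkdtemp())
(tmp / "broken_iface.py").write_text("import some_missing_dependency\n")
sys.path.insert(0, str(tmp))

# find_spec only locates the file; it does not execute it.
exists = importlib.util.find_spec("broken_iface") is not None

try:
    import broken_iface
    imported = True
except ModuleNotFoundError:
    imported = False  # raised by the transitive dependency, not the module itself

print(exists, imported)  # -> True False
```

Here an except ModuleNotFoundError fallback would wrongly conclude that broken_iface is absent and mask the real installation problem, which is exactly the scenario described above for a missing flash_attn_3._C.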

… warning

Address review feedback:
- Use importlib.util.find_spec() for explicit module checking instead of
  exception-driven control flow, avoiding masking real import errors
- Add a warning when the flat layout fallback is used for easier debugging
- Raise a clear error if flash_attn_interface is not found in either location

Signed-off-by: Emmanuel Attia <eattia@nvidia.com>
