
[JAX] Warmup FFIs with "initialize" stage #2800

Draft
jberchtold-nvidia wants to merge 2 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/warmup-xla-ffis

Conversation

@jberchtold-nvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (documentation-only change, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft on March 25, 2026 21:19
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 25, 2026

Greptile Summary

This PR extends the JAX backend of TransformerEngine by adding XLA FFI "initialize" stage handlers for the softmax, quantization, attention, GEMM, and router custom-call operations, completing warmup coverage that already existed for norm and activation layers.

Key changes:

  • Each new *InitializeFFI function is a thin wrapper that calls wrapInStreamCapture around the existing *FFI execute function. wrapInStreamCapture performs CUDA graph capture (cudaStreamBeginCapture / cudaStreamEndCapture) and immediately destroys the captured graph — this triggers JIT kernel compilation during the XLA "initialize" phase, so the first real execution does not incur that overhead.
  • All new *InitializeHandler symbols are declared in extensions.h and registered in pybind.cpp under the "initialize" key of their respective pybind11::dict entries.
  • GroupedQuantizeFFI (and the legacy GroupedGemm / GroupedGemmD2HGroupSizes variants) are intentionally excluded from initialize wrappers because they call cudaMemcpyAsync (D2H) followed by cudaStreamSynchronize, both of which are incompatible with CUDA stream capture mode.
  • The implementation is consistent in style and scope with the pre-existing initialize handlers for norm, activation, and GEMM operations.

Confidence Score: 5/5

  • Safe to merge — changes are purely additive, follow an established pattern, and introduce no behavioral differences to the execute path.
  • All new initialize handlers are mechanical wrappers around existing, battle-tested execute functions. The exclusion of GroupedQuantize and legacy GroupedGemm variants from the initialize stage is correct because those functions contain D2H memcpy + stream-synchronize calls that would fail inside CUDA graph capture. The argument/return-type bindings of every new handler match their corresponding execute handler exactly. No existing code is modified.
  • No files require special attention.

Important Files Changed

  • transformer_engine/jax/csrc/extensions.h: Adds XLA_FFI_DECLARE_HANDLER_SYMBOL forward declarations for all new *InitializeHandler symbols across quantization, softmax, attention, and router operations.
  • transformer_engine/jax/csrc/extensions/pybind.cpp: Updates Registrations() to wire all new *InitializeHandler symbols into the "initialize" key of their respective FFI dict entries; te_grouped_quantize_ffi is intentionally left without an initialize handler.
  • transformer_engine/jax/csrc/extensions/quantization.cpp: Adds DBiasQuantizeInitializeFFI/Handler and DequantizeInitializeFFI/Handler; both correctly wrap their execute counterparts via wrapInStreamCapture. GroupedQuantize is omitted because it performs a D2H memcpy + stream synchronize that is incompatible with CUDA graph capture.
  • transformer_engine/jax/csrc/extensions/softmax.cpp: Adds initialize handlers for all 6 softmax variants (3 forward + 3 backward); ScaledMaskedSoftmaxBackwardInitializeHandler correctly reuses ScaledSoftmaxBackwardInitializeFFI, matching the existing execute handler behavior.
  • transformer_engine/jax/csrc/extensions/attention.cpp: Adds FusedAttnForwardInitializeFFI/Handler and FusedAttnBackwardInitializeFFI/Handler; both use RemainingArgs + Attrs to mirror the execute handler bindings exactly.
  • transformer_engine/jax/csrc/extensions/router.cpp: Adds initialize handlers for all 4 router operations (TopK forward/backward, MoEAuxLoss forward/backward); pattern is consistent with the rest of the PR.
  • transformer_engine/jax/csrc/extensions/gemm.cpp: Adds GemmInitializeFFI/Handler, GemmV2InitializeFFI/Handler, and GroupedGemmV2InitializeFFI/Handler via wrapInStreamCapture; grouped GEMM variants that use D2H memcpy are correctly excluded.

Sequence Diagram

sequenceDiagram
    participant XLA as XLA/JAX Runtime
    participant Init as InitializeFFI (new)
    participant Wrap as wrapInStreamCapture
    participant CUDA as CUDA Driver
    participant Exec as FFI execute (existing)

    XLA->>Init: call during initialize stage
    Init->>Wrap: forward all args
    Wrap->>CUDA: cudaStreamBeginCapture
    Wrap->>Exec: call original FFI function
    Note over Exec,CUDA: kernels captured and JIT-compiled
    Exec-->>Wrap: Error_Type
    Wrap->>CUDA: cudaStreamEndCapture
    Wrap->>CUDA: cudaGraphDestroy
    Note over Wrap: warmup complete
    Wrap-->>Init: Error_Type
    Init-->>XLA: return

    XLA->>Exec: call during execute stage
    Exec->>CUDA: launch pre-compiled kernels
    CUDA-->>Exec: done
    Exec-->>XLA: return

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

