[PyT] Fix FSDP2 memory leaks for FP8 weight workspaces and transpose caches#2805
[PyT] Fix FSDP2 memory leaks for FP8 weight workspaces and transpose caches#2805pstjohn wants to merge 3 commits intoNVIDIA:mainfrom
Conversation
Add tests that demonstrate two known memory issues with FSDP2 + FP8: - Issue NVIDIA#2681: FP8 weight copies created during te.autocast() forward pass accumulate across layers instead of being freed between layers, defeating FSDP2's memory efficiency. Detected by comparing per-layer forward memory increments against a bf16 baseline using layer hooks. - Issue NVIDIA#2717: Transpose cache tensors (_create_transpose) allocated during backward persist until the next forward pass instead of being freed after backward completes. Detected by comparing the backward memory delta (post_bwd - post_fwd) against a bf16 baseline. New tests: - test_bf16_no_excess_forward_memory: control, validates per-layer measurement - test_bf16_no_excess_backward_memory: control, validates backward delta comparison - test_fp8_temp_accumulation_across_layers: xfail, detects NVIDIA#2681 - test_transpose_cache_retained_after_backward: xfail, detects NVIDIA#2717 All parametrized over 5 FP8 recipes x {no_quant_init, quant_init}. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… constant - Fix standalone runner to not pass recipe/quantized_model_init args to bf16 control tests (which take no arguments) - Fix stale comment referencing 4-layer model (now 8 layers) - Remove unused MEASURED_STEPS constant Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…caches Fix memory leaks where FP8 quantized weight copies and transpose caches accumulate during FSDP2 training, defeating FSDP2's per-layer memory savings (Issues NVIDIA#2681, NVIDIA#2717). Changes to layernorm_mlp.py, layernorm_linear.py, linear.py: - Detect FSDP2 via _get_module_fsdp_state; guard to tensor-scaling and MXFP8 quantizers whose backward re-creation is validated. - Skip columnwise/transpose creation on weight quantizers during forward so FP8 caches don't accumulate across layers. - Disable workspace caching (cache_name=None) under FSDP2 to prevent _fp8_workspaces from retaining per-layer copies. - Don't save separate FP8 workspace copies for backward; re-create from the FSDP2 all-gathered weight in backward instead. - Clear Float8TensorStorage._transpose after backward dgrad GEMMs to prevent transpose data persisting on FSDP2's reusable buffers. Test changes (run_fsdp2_mem_leak.py): - Remove xfail markers for fixed recipes (DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling). - Add targeted xfail for Float8BlockScaling/NVFP4BlockScaling whose blockwise storage classes have separate internal caching. - Increase backward test tolerance to 1 MiB to account for temporary workspace re-creation during backward. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Greptile SummaryThis PR fixes two FSDP2 memory leaks in TransformerEngine's FP8 training path — quantized weight workspaces accumulating across layers during forward (Issue #2681) and Confidence Score: 4/5Safe to merge with one minor follow-up: the bundled pre-existing bug fix in layernorm_mlp.py should be acknowledged; all other issues are style/naming. The core FSDP2 memory-leak fixes are logically sound and consistently applied across all three module files. The new test suite directly validates the fixed scenarios with appropriate tolerances and xfail markers for unresolved blockwise cases. Backward reconstruction of the FP8 workspace is correct: origin_weight is always saved (never None) and the quantizer is always non-None when wt_save=None can be set. The P1 comment covers an incidental fix to a pre-existing bug (wrong isinstance target) that changes non-FSDP2 behavior but is covered by passing functional tests. Remaining comments are P2 style/naming issues. transformer_engine/pytorch/module/layernorm_mlp.py — contains the bundled isinstance bug fix and the ctx.is_fsdp2 naming confusion; warrants a close read on the non-FSDP2 FC1 dgrad path. Important Files Changed
Sequence DiagramsequenceDiagram
participant FSDP2 as FSDP2 Runtime
participant FWD as Forward Pass (_Linear / _LayerNormLinear / _LayerNormMLP)
participant CTX as Autograd Context
participant BWD as Backward Pass
FSDP2->>FWD: All-gather weight (BF16/FP8 param)
FWD->>FWD: detect FSDP2 via _get_module_fsdp_state()
Note over FWD: is_fsdp2=True → columnwise=False,<br/>cache_name=None, wt_save=None
FWD->>FWD: get_weight_workspace() → weightmat (FP8, rowwise only)
FWD->>CTX: save(inputmat, wt_save=None, origin_weight, ...)
FWD-->>FSDP2: Reshard weight (buffer freed)
FSDP2->>BWD: All-gather weight (BF16/FP8 param)
BWD->>CTX: restore_from_func_ctx() → weight_fp8=None
BWD->>BWD: weight_fp8 is None → re-quantize(origin_weight, columnwise=True)
BWD->>BWD: dgrad GEMM (weight_fp8 transposed)
BWD->>BWD: weight._transpose = None (clear cache, Issue #2717)
BWD-->>FSDP2: Reshard weight (buffer freed, no dangling _transpose)
|
| ctx.fc2_weight_requires_grad = fc2_weight.requires_grad | ||
| ctx.fc1_weight = fc1_weight | ||
| ctx.fc2_weight = fc2_weight | ||
| ctx.is_fsdp2 = fsdp2_skip_columnwise |
There was a problem hiding this comment.
ctx.is_fsdp2 stores fsdp2_skip_columnwise, not the raw is_fsdp2 flag
fsdp2_skip_columnwise is is_fsdp2 and not is_recomputation and _is_safe_for_fsdp2, which is more restrictive than is_fsdp2. Storing it as ctx.is_fsdp2 could mislead a reader into thinking it simply encodes "this is an FSDP2 run." The backward's transpose-clearing guard (getattr(ctx, "is_fsdp2", False)) happens to be correct because you only want to clear during a real (non-recompute) backward, but the naming makes the invariant implicit.
Consider renaming to ctx.fsdp2_skip_columnwise (matching the local variable name) to make the semantics self-documenting.
| ctx.is_fsdp2 = fsdp2_skip_columnwise | |
| ctx.fsdp2_skip_columnwise = fsdp2_skip_columnwise |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if ctx.fc1_weight_quantizer is not None and isinstance( | ||
| ctx.fc1_weight_quantizer, QuantizedTensorStorage | ||
| ctx.fc1_weight, QuantizedTensorStorage | ||
| ): |
There was a problem hiding this comment.
Incidental bug fix:
isinstance target changed from quantizer to weight
The pre-existing code read:
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensorStorage # always False
):
ctx.fc1_weight.update_usage(columnwise_usage=True)Because a Quantizer is never a QuantizedTensorStorage, this condition was always False and update_usage was never called on the FC1 weight in this branch. This PR silently fixes that by checking ctx.fc1_weight instead — which is correct — but it also changes backward behavior for non-FSDP2 paths (e.g., fp8_init with primary-FP8 weights). The passing functional tests provide coverage, but a comment or separate commit noting this as a pre-existing bug fix would keep the diff more reviewable.
Summary
Fixes memory leaks where FP8 quantized weight workspaces and transpose caches accumulate during FSDP2 training, defeating FSDP2's per-layer memory savings.
_create_transposetensors allocated during backward persisted on FSDP2's reusable all-gather buffers until the next forward pass.Approach
When FSDP2 is detected (via
_get_module_fsdp_state), the fix applies tolayernorm_mlp.py,layernorm_linear.py, andlinear.py:cache_name=None) — prevents_fp8_workspacesfrom retaining per-layer copies_transposeafter backward dgrad GEMMs — prevents transpose data persisting on reusable buffersGuarded to Float8Quantizer, Float8CurrentScalingQuantizer, and MXFP8Quantizer. Blockwise quantizers (Float8BlockScaling, NVFP4BlockScaling) have separate internal caching not yet addressed.
Test changes
Updates the memory leak detection tests from #2803:
xfailmarkers for fixed recipesxfailfor blockwise recipesTest plan
test_torch_fsdp2.py(4 passed)🤖 Generated with Claude Code
Closes #2681
Closes #2717