
If model parameters are DTensors, optimizer states should also be DTensors.#2795

Open
cspades wants to merge 12 commits into NVIDIA:main from cspades:cye/fused-adam-dcp

Conversation

@cspades
Member

@cspades cspades commented Mar 24, 2026

Description

  • There is a bug where, if the model parameters are DTensors (either FSDP2-sharded parameters or Megatron-FSDP FP32 main-weight DTensors), FusedAdam allocates its optimizer state as a plain, non-distributed Tensor, which Torch DCP then treats as a global / un-sharded tensor in the state dictionary. This PR wraps each optimizer state as a DTensor matching the device mesh, placements, and global shape of the DTensor parameter it is associated with.

Fixes a bug stemming from the new DTensor(QuantizedTensor) (FSDP2-only) use case introduced in #2698 (Megatron-FSDP just uses DTensor(Float32) for the distributed optimizer state).

Testing

  • TE CI/CD
TE_PATH=/workspace/TransformerEngine ./qa/L1_pytorch_distributed_unittest/test.sh
  • Megatron-LM + --use-precision-aware-optimizer
# TE@00ba0b493c27f32e2f210b0022132c50da78dac7 (Llama 8B + Precision-Aware Optimizer + FP8Blockwise + TP2 + GB300)
[2026-03-25 15:18:07.588704] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 9614.0 | throughput per GPU (TFLOP/s/GPU): 1403.6 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.131176E+00 | loss scale: 1.0 | grad norm: 5.337 | number of skipped iterations:   0 | number of nan iterations:   0 |

# This PR (Llama 8B Precision-Aware Optimizer + FP8Blockwise + TP2 + GB300)
[2026-03-25 14:58:55.856189] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 9588.0 | throughput per GPU (TFLOP/s/GPU): 1407.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.131045E+00 | loss scale: 1.0 | grad norm: 5.336 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • To reproduce the error motivating this PR, use the broken FusedAdam code before this PR/commit and run:
torchrun --nproc-per-node 4 -m pytest tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -s -k "dcp_resharding_save" && torchrun --nproc-per-node 2 -m pytest tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -s -k "dcp_resharding_load"

E               raise ValueError(
E           ValueError: Size mismatch between saved torch.Size([64]) and current: torch.Size([128]) for optimizer.state.0.exp_avg
E           Traceback (most recent call last): (RANK 1)
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 193, in reduce_scatter
E               local_data = map_fun()
E                            ^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/logger.py", line 90, in wrapper
E               result = func(*args, **kwargs)
E                        ^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 269, in local_step
E               local_plan = planner.create_local_plan()
E                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/default_planner.py", line 352, in create_local_plan
E               return create_default_local_load_plan(
E                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/default_planner.py", line 485, in create_default_local_load_plan
E               raise ValueError(
E           ValueError: Size mismatch between saved torch.Size([64]) and current: torch.Size([128]) for optimizer.state.0.exp_avg

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cspades and others added 3 commits March 24, 2026 09:48
@cspades cspades marked this pull request as ready for review March 24, 2026 17:42
@cspades
Member Author

cspades commented Mar 24, 2026

@greptile-apps
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR fixes a bug where FusedAdam optimizer states were allocated as plain (non-distributed) tensors even when the associated model parameters were DTensors (FSDP2 sharded or Megatron-FSDP main weights). Because Torch DCP treats non-distributed tensors as global/un-sharded, this caused incorrect checkpoint save/load behaviour when using a distributed optimizer with DTensor params. The fix wraps each freshly initialized optimizer state in DTensor.from_local (using the parent param's device mesh, placements, and global shape) immediately after allocation, and consistently unwraps to the local shard whenever CUDA kernels or scaling operations are invoked.

Key changes:

  • _initialize_state: after creating the raw (or FP8-quantized) state tensor, wraps it in DTensor.from_local when param is a DTensor. Also fixes the pre-existing shape bug for FP8 states: quantizer.make_empty now uses data.shape (local shard shape) instead of param.shape (global shape).
  • get_unscaled_state / set_scaled_state: unwrap DTensor to ._local_tensor before any scaling or dtype assertion, making them DTensor-aware.
  • state_dict: after calling get_unscaled_state, re-wraps the returned FP32 local tensor as a DTensor so DCP can checkpoint and reshard it correctly.
  • load_state_dict: unwraps incoming DTensor states (which DCP may have resharded to the new topology) before calling set_scaled_state.
  • step: adds explicit DTensor parity assertions (param ↔ grad ↔ each optimizer state) and unwraps states to local tensors before appending to the CUDA kernel lists.
  • New test test_fsdp2_fused_adam_dcp_resharding: two-phase DCP resharding test (train + save with 4 ranks, reload with 2 ranks) verifying that the DTensor optimizer states are correctly resharded and the model output is bitwise-identical after loading.

Confidence Score: 4/5

  • The core fix is correct and all previously raised concerns (param.shape vs data.shape for FP8, DTensor check consistency in set_scaled_state) are addressed; the Megatron-LM + precision-aware optimizer validation is still TBA per the PR description.
  • The DTensor-wrapping logic is consistent across initialization, step, state_dict, and load_state_dict. The FP8 shape-mismatch regression is fixed (data.shape instead of param.shape). DTensor parity assertions in step() catch mismatches early. The new two-phase DCP resharding test is well-designed with appropriate xfail guards and fixed seeds. One remaining gap is that the Megatron-LM end-to-end test is listed as TBA in the description, so real-world validation beyond the TE unit tests is pending.
  • transformer_engine/pytorch/optimizers/fused_adam.py — the load_state_dict path deserves close scrutiny in code review since it interleaves super().load_state_dict() (for id mapping) with a full manual re-initialization loop that discards the superclass state.

Important Files Changed

Filename Overview
transformer_engine/pytorch/optimizers/fused_adam.py Core bug fix: wraps optimizer states as DTensors when params are DTensors. Correctly unwraps to local tensors for CUDA kernels and re-wraps for DCP. Fixes the param.shape → data.shape regression for FP8 states. Adds DTensor parity assertions in step().
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Adds test_dcp_resharding_save and test_dcp_resharding_load as two-phase tests for cross-topology DCP resharding (DP4 → DP2). Appropriate xfail guards and fixed-seed input for deterministic comparison.
tests/pytorch/distributed/test_torch_fsdp2.py Adds test_fsdp2_fused_adam_dcp_resharding which orchestrates two sequential torchrun invocations (4 ranks then 2 ranks) to validate cross-topology checkpoint resharding via DCP.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A{param is DTensor?} -->|Yes| B[Extract local_param via _local_tensor]
    A -->|No| C[local_param = param]
    B --> D[Dequantize if QuantizedTensor]
    C --> D
    D --> E[Allocate data with target dtype and local shape]
    E --> F{dtype == uint8?}
    F -->|Yes| G[Float8Quantizer.make_empty using data.shape]
    G --> H[quantize_ into Float8Tensor state]
    F -->|No| I[state = data plain tensor]
    H --> J{param is DTensor?}
    I --> J
    J -->|Yes| K[DTensor.from_local with global shape and placements]
    J -->|No| L[Store plain tensor in self.state]
    K --> L

    L --> M[step - unwrap DTensor to local tensor for CUDA kernel]
    L --> N[state_dict - unscale then re-wrap as DTensor for DCP]
    N --> O[load_state_dict - DCP provides resharded DTensor]
    O --> P[Unwrap local tensor, call set_scaled_state to rescale and store]

Reviews (7): Last reviewed commit: "Merge branch 'main' into cye/fused-adam-..."

@vthumbe1503
Collaborator

@cspades could you please elaborate on the downstream error/issue caused. As in what happens if we load the unsharded tensor for optimizer state as plain tensor instead of DTensor?

@cspades
Member Author

cspades commented Mar 24, 2026

> @cspades could you please elaborate on the downstream error/issue caused. As in what happens if we load the unsharded tensor for optimizer state as plain tensor instead of DTensor?

Here is how I understand it (@shjwudp, correct me if I am wrong about the Megatron-FSDP details, as I still need to reproduce the bug and confirm this PR fixes it). I believe a customer reported this bug?

  • #2698 (Add fused_adam, quantized_model_init, and fsdp2 example) introduced logic in FusedAdam.__init__ such that if the TE model parameters are DTensors, the optimizer state is created as a plain Tensor.
    • The reason is that empty_like does not pick up the correct dtype from a DTensor (per the in-line commentary) when the local data is a QuantizedTensor. Note that Megatron-FSDP's main weights are FP32, not QuantizedTensor, so our code worked with the original FusedAdam.
  • When Megatron-FSDP (or Megatron-LM's distributed optimizer) performs its first optimizer.step(), it exposes FP32 DTensor main weights to the FusedAdam optimizer, and because of the above logic, plain Tensor optimizer states are constructed from those DTensor main weights.
  • Megatron-FSDP depends on DTensor optimizer states for DCP checkpointing of FusedAdam's state, because we employ un-even sharding. Instead, it now sees plain Tensors, which can break our DCP integration and/or the un-even DTensor sharding metadata.

The fix is to keep the optimizer state in DTensor form whenever the model parameters are DTensors, and to localize to (or perform in-place operations on) the local Tensor for all FusedAdam operations.
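As a sketch of what "perform in-place operations on the local Tensor" means, here is an Adam first-moment update applied to the local shard. This is illustrative only: `update_exp_avg_` is a made-up name, and for a DTensor state one would pass `state._local_tensor` (or `state.to_local()`) as the first argument.

```python
import torch


def update_exp_avg_(exp_avg_local: torch.Tensor, grad_local: torch.Tensor,
                    beta1: float = 0.9) -> torch.Tensor:
    """Adam first-moment update, in place on the local shard. Because the
    write is in place, a DTensor wrapper around exp_avg_local (and its
    sharding metadata) remains valid after the update."""
    return exp_avg_local.mul_(beta1).add_(grad_local, alpha=1.0 - beta1)
```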

cspades and others added 7 commits March 24, 2026 12:01
Add Greptile bug-fixes.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… re-sharding test.

Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades
Member Author

cspades commented Mar 25, 2026

Contributor

@pstjohn pstjohn left a comment


TLDR it would fail if you train on 4 ranks and load on 2 ranks, this adds a test for this.

(among other issues with mFSDP)

Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades
Member Author

cspades commented Mar 25, 2026

/te-ci L1 pytorch
