UNet2DOutput becomes nested when wrapped with accelerate #560
Closed
Labels
bug (Something isn't working)
Description
Describe the bug
When either UNet2DModel or UNet2DConditionModel is prepared with accelerate, its output becomes nested: to get the sample you have to use outputs.sample['sample'] instead of just outputs.sample.
However, it works as expected with return_dict=False.
cc @patrickvonplaten @patil-suraj @sgugger
Reproduction
- Regular outputs:
import torch
from accelerate import Accelerator
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=16,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(128,),
    down_block_types=("DownBlock2D",),
    up_block_types=("UpBlock2D",),
)
model = model.cuda()
x = torch.randn((1, 3, 16, 16)).to(model.device)
t = torch.tensor([0]).to(model.device)
print(model(x, t))

UNet2DOutput(sample=tensor([[[[ 6.6783e-02, 1.1064e-01, -3.1808e-01, -2.7287e-01, 7.0199e-02,
-1.7103e-01, 4.0218e-01, 1.7775e-01, 6.2583e-01, -1.4165e-01,
9.0453e-02, -8.6686e-02, -3.2533e-01, 3.4424e-02, 4.2162e-02,
8.4619e-02],
...
- Outputs after accelerator.prepare:
accelerator = Accelerator()
model = accelerator.prepare(model)
print(model(x, t))

UNet2DOutput(sample={'sample': tensor([[[[ 6.6467e-02, 1.1072e-01, -3.1836e-01, -2.7295e-01, 6.9458e-02,
-1.7090e-01, 4.0234e-01, 1.7761e-01, 6.2598e-01, -1.4160e-01,
9.0271e-02, -8.6182e-02, -3.2568e-01, 3.4332e-02, 4.2419e-02,
8.4656e-02],
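
Until the nesting is fixed, calling code could read the sample defensively so it works with both output shapes shown above. The following is only a minimal sketch: `unwrap_sample` and `FakeUNetOutput` are hypothetical names introduced for illustration (the latter stands in for UNet2DOutput so the snippet runs without diffusers or a GPU), and it assumes the nested value is always a dict with a 'sample' key, as in the printout above.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeUNetOutput:
    """Hypothetical stand-in for UNet2DOutput; not part of diffusers."""
    sample: Any


def unwrap_sample(output: Any) -> Any:
    """Return the sample whether or not the output was nested (hypothetical helper)."""
    sample = output.sample
    # Assumption: the nested form is a dict holding the real value under 'sample',
    # matching the output observed after accelerator.prepare.
    if isinstance(sample, dict) and "sample" in sample:
        return sample["sample"]
    return sample


# Plain lists stand in for tensors here.
regular = FakeUNetOutput(sample=[0.1, 0.2])
nested = FakeUNetOutput(sample={"sample": [0.1, 0.2]})
assert unwrap_sample(regular) == unwrap_sample(nested)
```

With the real model this would be used as unwrap_sample(model(x, t)). Passing return_dict=False, as noted in the description, avoids the problem without a helper.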
Logs
No response
System Info
diffusers: 0.3.0
accelerate: 0.12.0
torch: 1.12.1+cu113
single GPU, colab