56 changes: 55 additions & 1 deletion docs/source/en/api/pipelines/hunyuandit.md
@@ -28,11 +28,65 @@ HunyuanDiT has the following components:
* It uses a diffusion transformer as the backbone
* It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

## Optimization

You can optimize the pipeline's runtime and memory consumption with `torch.compile` and feed-forward chunking. To learn about other optimization methods, check out the [Speed up inference](../../optimization/fp16) and [Reduce memory usage](../../optimization/memory) guides.

### Inference

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

First, load the pipeline:

```python
from diffusers import HunyuanDiTPipeline
import torch

pipeline = HunyuanDiTPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
```

Then change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:

```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```

Finally, compile the components and run inference:

```python
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

image = pipeline(prompt="一个宇航员在骑马").images[0]  # "An astronaut riding a horse"
```

The [benchmark](https://gist.github.com/sayakpaul/29d3a14905cfcbf611fe71ebd22e9b23) results on an 80GB A100 machine are:

```bash
With torch.compile(): Average inference time: 12.470 seconds.
Without torch.compile(): Average inference time: 20.570 seconds.
```
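
If you want to sanity-check the speedup on your own hardware, a minimal timing loop like the sketch below is usually enough. This is not the linked benchmark script; the warmup and iteration counts are arbitrary choices, and it reuses the `pipeline` object compiled above.

```python
import time

import torch

prompt = "一个宇航员在骑马"  # "An astronaut riding a horse"

# Warmup runs so torch.compile's one-time compilation cost isn't counted.
for _ in range(3):
    _ = pipeline(prompt=prompt)

torch.cuda.synchronize()
num_runs = 5
start = time.time()
for _ in range(num_runs):
    _ = pipeline(prompt=prompt)
torch.cuda.synchronize()

print(f"Average inference time: {(time.time() - start) / num_runs:.3f} seconds.")
```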

### Memory optimization

By loading the T5 text encoder in 8 bits, you can run the pipeline in just under 6GB of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.
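
As a rough illustration of the idea (not the linked script, which contains the exact recipe behind the 6GB figure), you can load just the T5 encoder with `bitsandbytes` 8-bit quantization and pass it to the pipeline. The `subfolder` name and keyword arguments below are assumptions based on the standard `transformers`/`diffusers` loading APIs, and this simplified flow on its own may not hit the 6GB number.

```python
import torch
from transformers import T5EncoderModel
from diffusers import HunyuanDiTPipeline

# Load only the T5 text encoder in 8-bit precision (requires the bitsandbytes package).
text_encoder_2 = T5EncoderModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",
    subfolder="text_encoder_2",
    load_in_8bit=True,
    device_map="auto",
)

# Reuse the quantized encoder so the pipeline doesn't load a full-precision copy of it.
pipeline = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.float16,
).to("cuda")  # the 8-bit encoder stays where bitsandbytes placed it

image = pipeline(prompt="一个宇航员在骑马").images[0]  # "An astronaut riding a horse"
```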

Furthermore, you can use the [`~HunyuanDiT2DModel.enable_forward_chunking`] method to reduce memory usage. Feed-forward chunking runs the feed-forward layers in a transformer block in a loop instead of all at once. This gives you a trade-off between memory consumption and inference runtime.

```diff
+ pipeline.transformer.enable_forward_chunking(chunk_size=1, dim=1)
```

Review discussion on this block:

Member (suggested change): use a `py` fence instead of a `diff` fence.

Member Author: It's actually diff in nature. Notice the "+" in the code.

Member: Yeah, but I think it may be better to change it to python so users who copy/paste the code example don't get the "+" in it.

Member Author: I think our users will understand that difference given we make use of diff quite often throughout our docs.
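
For reference, a minimal end-to-end sketch of the chunking call in plain Python, assuming the same checkpoint as the loading example above:

```python
import torch
from diffusers import HunyuanDiTPipeline

pipeline = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")

# Run the transformer's feed-forward layers in a loop over chunks instead of all at
# once to lower peak memory, at some cost in inference speed.
pipeline.transformer.enable_forward_chunking(chunk_size=1, dim=1)

image = pipeline(prompt="一个宇航员在骑马").images[0]  # "An astronaut riding a horse"
```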


## HunyuanDiTPipeline

[[autodoc]] HunyuanDiTPipeline