unet time embedding activation function #3048
Conversation
if self.time_embed_act is not None:
    emb = self.time_embed_act(emb)
optional activation of time embeddings once at the beginning of the unet
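The guarded call above applies the activation a single time, after the timestep embedding MLP. A minimal sketch of where it sits (simplified names and layer sizes; not the actual diffusers module, which has more options):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class TimestepEmbedding(nn.Module):
    """Minimal sketch of a timestep-embedding MLP with an optional
    final activation, as discussed in this PR. Not the real diffusers class."""

    def __init__(self, in_dim: int, embed_dim: int,
                 time_embedding_act_fn: Optional[str] = None):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, embed_dim)
        self.linear_2 = nn.Linear(embed_dim, embed_dim)
        # None keeps the previous behavior: no extra activation at the end
        self.time_embed_act = (
            nn.SiLU() if time_embedding_act_fn in ("swish", "silu") else None
        )

    def forward(self, t_emb: torch.Tensor) -> torch.Tensor:
        emb = self.linear_2(F.silu(self.linear_1(t_emb)))
        # applied once, at the end of the embedding MLP
        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)
        return emb
```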
Out of curiosity.
Is it being used in the private fork?
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from 55412f3 to df4eb1b
if act_fn == "swish":
    self.time_embed_act = lambda x: F.silu(x)
elif act_fn == "mish":
    self.time_embed_act = nn.Mish()
elif act_fn == "silu":
    self.time_embed_act = nn.SiLU()
Can't we do?

if act_fn in ["swish", "silu"]:
    self.time_embed_act = nn.SiLU()
Yes, I would hope we could :) This is how it's done in a few other places in the code base, so I'd like to leave it this way for now and do a follow-up that refactors all the dispatches to the different activation functions.
patrickvonplaten
left a comment
Ok for me! BTW, it's totally fine with me to create a get_act_fn function and an activations file as we do in transformers here: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
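A transformers-style dispatch helper for the follow-up refactor could look roughly like this (the mapping name and set of supported activations are assumptions; the PR only suggests the pattern):

```python
import torch.nn as nn

# Hypothetical name-to-class mapping; the actual refactor may differ.
ACT2CLS = {
    "swish": nn.SiLU,  # "swish" and "silu" are the same function
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}

def get_activation(act_fn: str) -> nn.Module:
    """Return a fresh activation module for the given name.

    Centralizing the dispatch replaces the repeated if/elif chains
    scattered through the code base with a single lookup.
    """
    try:
        return ACT2CLS[act_fn]()
    except KeyError:
        raise ValueError(f"Unsupported activation function: {act_fn}")
```

Callers would then replace the if/elif chain with `self.time_embed_act = get_activation(act_fn)`.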
Force-pushed from 9ae2f30 to 0660fe0
Force-pushed from f4a5a17 to e309542
* unet time embedding activation function
* typo act_fn -> time_embedding_act_fn
* flatten conditional
See PR comments