Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
*.egg-info
build
/outputs
/checkpoints
/checkpoints
__pycache__
6 changes: 4 additions & 2 deletions scripts/demo/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def run_txt2img(
use_identity_guider=not version_dict["is_guided"]
)

num_samples = num_rows * num_cols
# num_samples = num_rows * num_cols
num_samples = 1

if st.button("Sample"):
if True:
st.write(f"**Model I:** {version}")
out = do_sample(
state["model"],
Expand Down Expand Up @@ -307,6 +308,7 @@ def apply_refiner(
)
else:
raise ValueError(f"unknown mode {mode}")

if isinstance(out, (tuple, list)):
samples, samples_z = out
else:
Expand Down
5 changes: 4 additions & 1 deletion scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors
from pytorch_lightning import seed_everything

from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
Expand Down Expand Up @@ -206,6 +207,7 @@ def init_save_locally(_dir, init_value: bool = False):
else:
save_path = None

return True, "/home/patrick/images/sgm_test"
return save_locally, save_path


Expand Down Expand Up @@ -513,7 +515,8 @@ def do_sample(
additional_model_inputs[k] = batch[k]

shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
seed_everything(0)
randn = torch.randn(shape, device="cuda")

def denoiser(input, sigma, c):
return model.denoiser(
Expand Down