From fb3f7c901d5583d73acb1b4fe631123677d99b49 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 1 Apr 2024 07:04:26 +0000 Subject: [PATCH] update --- tests/pipelines/test_pipelines_common.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d7f0c6baa339..c6b89c285b94 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -131,11 +131,15 @@ def test_freeu_enabled(self): inputs = self.get_dummy_inputs(torch_device) inputs["return_dict"] = False + inputs["output_type"] = "np" + output = pipe(**inputs)[0] pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) inputs = self.get_dummy_inputs(torch_device) inputs["return_dict"] = False + inputs["output_type"] = "np" + output_freeu = pipe(**inputs)[0] assert not np.allclose( @@ -150,6 +154,8 @@ def test_freeu_disabled(self): inputs = self.get_dummy_inputs(torch_device) inputs["return_dict"] = False + inputs["output_type"] = "np" + output = pipe(**inputs)[0] pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4) @@ -162,6 +168,8 @@ def test_freeu_disabled(self): inputs = self.get_dummy_inputs(torch_device) inputs["return_dict"] = False + inputs["output_type"] = "np" + output_no_freeu = pipe(**inputs)[0] assert np.allclose( output, output_no_freeu, atol=1e-2