diff --git a/tests/metrics/test_ssim_metric.py b/tests/metrics/test_ssim_metric.py index d79107e999..1659d2ce4a 100644 --- a/tests/metrics/test_ssim_metric.py +++ b/tests/metrics/test_ssim_metric.py @@ -21,7 +21,7 @@ class TestSSIMMetric(unittest.TestCase): - def test2d_gaussian(self): + def test_2d_gaussian(self): set_determinism(0) preds = torch.abs(torch.randn(2, 3, 16, 16)) target = torch.abs(torch.randn(2, 3, 16, 16)) @@ -32,9 +32,9 @@ def test2d_gaussian(self): metric(preds, target) result = metric.aggregate() expected_value = 0.045415 - self.assertTrue(expected_value - result.item() < 0.000001) + self.assertTrue(abs(expected_value - result.item()) < 0.000001) - def test2d_uniform(self): + def test_2d_uniform(self): set_determinism(0) preds = torch.abs(torch.randn(2, 3, 16, 16)) target = torch.abs(torch.randn(2, 3, 16, 16)) @@ -45,9 +45,9 @@ def test2d_uniform(self): metric(preds, target) result = metric.aggregate() expected_value = 0.050103 - self.assertTrue(expected_value - result.item() < 0.000001) + self.assertTrue(abs(expected_value - result.item()) < 0.000001) - def test3d_gaussian(self): + def test_3d_gaussian(self): set_determinism(0) preds = torch.abs(torch.randn(2, 3, 16, 16, 16)) target = torch.abs(torch.randn(2, 3, 16, 16, 16))