From 15c166ac127db5c8d1541b3485ef5730d34bb68a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Nov 2023 09:51:43 +0100 Subject: [PATCH] refactor to_pil_image and align array with tensor inputs (#8097) Co-authored-by: Nicolas Hug --- test/test_transforms.py | 10 ++++--- torchvision/transforms/functional.py | 43 +++++++++------------------- 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 7c92baa9f5c..16d0e7e5d94 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self): @pytest.mark.parametrize( "img_data, expected_mode", [ - (torch.Tensor(4, 4, 1).uniform_().numpy(), "F"), + (torch.Tensor(4, 4, 1).uniform_().numpy(), "L"), (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"), (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"), (torch.IntTensor(4, 4, 1).random_().numpy(), "I"), @@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) assert img.mode == expected_mode + if np.issubdtype(img_data.dtype, np.floating): + img_data = (img_data * 255).astype(np.uint8) # note: we explicitly convert img's dtype because pytorch doesn't support uint16 # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype)) @@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe @pytest.mark.parametrize( "img_data, expected_mode", [ - (torch.Tensor(4, 4).uniform_().numpy(), "F"), + (torch.Tensor(4, 4).uniform_().numpy(), "L"), (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"), (torch.ShortTensor(4, 4).random_().numpy(), "I;16"), (torch.IntTensor(4, 4).random_().numpy(), "I"), @@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) assert img.mode == expected_mode + if np.issubdtype(img_data.dtype, np.floating): + img_data = (img_data * 255).astype(np.uint8) np.testing.assert_allclose(img_data, img) @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) @@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self): trans(np.ones([4, 4, 1], np.uint16)) with pytest.raises(TypeError, match=reg_msg): trans(np.ones([4, 4, 1], np.uint32)) - with pytest.raises(TypeError, match=reg_msg): - trans(np.ones([4, 4, 1], np.float64)) with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d176e00a8da..7cbe2d99071 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(to_pil_image) - if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): + if isinstance(pic, torch.Tensor): + if pic.ndim == 3: + pic = pic.permute((1, 2, 0)) + pic = pic.numpy(force=True) + elif not isinstance(pic, np.ndarray): raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.") - elif isinstance(pic, torch.Tensor): - if pic.ndimension() not in {2, 3}: - raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.") - - elif pic.ndimension() == 2: - # if 2D image, add channel dimension (CHW) - pic = pic.unsqueeze(0) - - # check number of channels - if pic.shape[-3] > 4: - raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.") - - elif isinstance(pic, np.ndarray): - if pic.ndim not in {2, 3}: - raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.") - - elif pic.ndim == 2: - # if 2D image, add channel dimension (HWC) - pic = np.expand_dims(pic, 2) + if pic.ndim == 2: + # if 2D image, add channel dimension (HWC) + pic = np.expand_dims(pic, 2) + if pic.ndim != 3: + raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.") - # check number of channels - if pic.shape[-1] > 4: - raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.") + if pic.shape[-1] > 4: + raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.") npimg = pic - if isinstance(pic, torch.Tensor): - if pic.is_floating_point() and mode != "F": - pic = pic.mul(255).byte() - npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) - if not isinstance(npimg, np.ndarray): - raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}") + if np.issubdtype(npimg.dtype, np.floating) and mode != "F": + npimg = (npimg * 255).astype(np.uint8) if npimg.shape[2] == 1: expected_mode = None