Skip to content

Commit

Permalink
Reduced number of graphs for compiled resize (pytorch#8108)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
vfdev-5 and NicolasHug authored Nov 22, 2023
1 parent 893b4ab commit ab4c102
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ def resize(
return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


# This is an internal helper method for resize_image. We should put it here instead of keeping it
# inside resize_image due to torchscript.
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks on eager, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
def _do_native_uint8_resize_on_cpu(interpolation: InterpolationMode) -> bool:
if interpolation == InterpolationMode.BILINEAR:
if torch._dynamo.is_compiling():
return True
else:
return "AVX2" in torch.backends.cpu.get_cpu_capability()

return interpolation == InterpolationMode.BICUBIC


@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
Expand Down Expand Up @@ -215,21 +229,16 @@ def resize_image(
if (new_height, new_width) == (old_height, old_width):
return image
elif numel > 0:
image = image.reshape(-1, num_channels, old_height, old_width)

dtype = image.dtype
acceptable_dtypes = [torch.float32, torch.float64]
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
# uint8 dtype can be included for cpu and cuda input if nearest mode
acceptable_dtypes.append(torch.uint8)
elif image.device.type == "cpu":
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
interpolation == InterpolationMode.BICUBIC
):
if _do_native_uint8_resize_on_cpu(interpolation):
acceptable_dtypes.append(torch.uint8)

image = image.reshape(-1, num_channels, old_height, old_width)
strides = image.stride()
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
# There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
Expand Down

0 comments on commit ab4c102

Please sign in to comment.