Skip to content

Commit

Permalink
relax the constraint for uint8 tensor dtype for the 2 display functio…
Browse files Browse the repository at this point in the history
…ns (#115)

* add a `image display` roxygen family

* relax the constraint on torch_tensor type in display functions
add tests

* add NEWS

* use a more secure function

* Revert "use a more secure function"

This reverts commit 9e0fe37.

* remove unexpected normalization

---------

Co-authored-by: Daniel Falbel <[email protected]>
  • Loading branch information
cregouby and dfalbel authored Sep 17, 2024
1 parent 0f6456f commit fe38f90
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 19 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# torchvision (development version)

- `tensor_image_display` and `tensor_image_browse` now accept all tensor_image dtypes. (#115, @cregouby)
- fix `transform_affine` help to remove confusion with `transforme_random_affine` help (#116, @cregouby)
- add message translation in french (#112, @cregouby)

Expand Down
37 changes: 27 additions & 10 deletions R/vision_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ NULL
#' @param padding amount of padding between batch images (default 2).
#' @param pad_value pixel value to use for padding.
#'
#' @family image display
#' @export
vision_make_grid <- function(tensor,
scale = TRUE,
Expand Down Expand Up @@ -91,6 +92,7 @@ vision_make_grid <- function(tensor,
#' tensor_image_browse(bboxed)
#' }
#' }
#' @family image display
#' @export
draw_bounding_boxes <- function(image,
boxes,
Expand Down Expand Up @@ -193,6 +195,7 @@ draw_bounding_boxes <- function(image,
#' masked_image <- draw_segmentation_masks(image, mask, alpha = 0.2)
#' tensor_image_browse(masked_image)
#' }
#' @family image display
#' @export
draw_segmentation_masks <- function(image,
masks,
Expand Down Expand Up @@ -261,6 +264,7 @@ draw_segmentation_masks <- function(image,
#' tensor_image_browse(keypoint_image)
#' }
#' }
#' @family image display
#' @export
draw_keypoints <- function(image,
keypoints,
Expand Down Expand Up @@ -314,19 +318,25 @@ draw_keypoints <- function(image,
#' Display image tensor
#'
#' Display image tensor onto the X11 device
#' @param image `torch_tensor()` of shape (1, W, H) for grayscale image or (3, W, H) for color image,
#' of type `torch_uint8()` to display
#' @param image `torch_tensor()` of shape (1, W, H) for grayscale image or (3, W, H) for
#' color image to display
#' @param animate support animations in the X11 display
#'
#' @family image display
#' @export
tensor_image_display <- function(image, animate = TRUE) {
stopifnot("`image` is expected to be of dtype torch_uint8" = image$dtype == torch::torch_uint8())
stopifnot("Pass individual images, not batches" = image$ndim == 3)
stopifnot("Only grayscale and RGB images are supported" = image$size(1) %in% c(1, 3))

img_to_draw <- image$permute(c(2, 3, 1))$to(device = "cpu", dtype = torch::torch_long()) %>% as.array
if (image$dtype == torch::torch_uint8()) {
img_to_draw <- image$permute(c(2, 3, 1))$to(device = "cpu", dtype = torch::torch_long()) %>%
as.array() / 255

png::writePNG(img_to_draw / 255) %>% magick::image_read() %>% magick::image_display(animate = animate)
} else {
img_to_draw <- image$permute(c(2, 3, 1))$to(device = "cpu") %>%
as.array()
}
png::writePNG(img_to_draw) %>% magick::image_read() %>% magick::image_display(animate = animate)

invisible(NULL)
}
Expand All @@ -335,19 +345,26 @@ tensor_image_display <- function(image, animate = TRUE) {
#' Display image tensor
#'
#' Display image tensor into browser
#' @param image `torch_tensor()` of shape (1, W, H) for grayscale image or (3, W, H) for color image,
#' of type `torch_uint8()` to display
#' @param image `torch_tensor()` of shape (1, W, H) for grayscale image or (3, W, H) for
#' color image to display
#' @param browser argument passed to [browseURL]
#'
#' @family image display
#' @export
tensor_image_browse <- function(image, browser = getOption("browser")) {
stopifnot("`image` is expected to be of dtype torch_uint8" = image$dtype == torch::torch_uint8())
stopifnot("Pass individual images, not batches" = image$ndim == 3)
stopifnot("Only grayscale and RGB images are supported" = image$size(1) %in% c(1, 3))

img_to_draw <- image$permute(c(2, 3, 1))$to(device = "cpu", dtype = torch::torch_long()) %>% as.array
if (image$dtype == torch::torch_uint8()) {
img_to_draw <- image$permute(c(2, 3, 1))$to(device = "cpu", dtype = torch::torch_long()) %>%
as.array() / 255

} else {
img_to_draw <- image$permute(c(2, 3, 1))$to(device = "cpu") %>%
as.array()
}

png::writePNG(img_to_draw / 255) %>% magick::image_read() %>% magick::image_browse(browser = browser)
png::writePNG(img_to_draw) %>% magick::image_read() %>% magick::image_browse(browser = browser)

invisible(NULL)
}
9 changes: 9 additions & 0 deletions man/draw_bounding_boxes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions man/draw_keypoints.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions man/draw_segmentation_masks.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions man/tensor_image_browse.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions man/tensor_image_display.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions man/vision_make_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 10 additions & 5 deletions tests/testthat/test-vision-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,20 @@ test_that("draw_keypoints works", {
test_that("tensor_image_browse works", {
skip_on_cran()
skip_on_ci()
# color image
# uint8 color image
image <- (255 - (torch::torch_randint(low = 1, high = 200, size = c(3, 360, 360))))$to(torch::torch_uint8())
expect_no_error(tensor_image_browse(image))
# grayscale image
# uint8 grayscale image
image <- (255 - (torch::torch_randint(low = 1, high = 200, size = c(1, 360, 360))))$to(torch::torch_uint8())
expect_no_error(tensor_image_browse(image))
# error cases : dtype
image_int16 <- image$to(torch::torch_int16())
expect_error(tensor_image_browse(image_int16), "dtype torch_uint8")

# float color image
image <- torch::torch_rand(size = c(3, 360, 360))
expect_no_error(tensor_image_browse(image))
# float grayscale image
image <- torch::torch_rand(size = c(1, 360, 360))
expect_no_error(tensor_image_browse(image))

# error cases : shape
image <- torch::torch_randint(low = 1, high = 200, size = c(4, 3, 360, 360))$to(torch::torch_uint8())
expect_error(tensor_image_browse(image), "individual images")
Expand Down

0 comments on commit fe38f90

Please sign in to comment.