Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Oct 7, 2024
1 parent 6c06864 commit b0dcf76
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions lib/bumblebee/utils/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,22 @@ defmodule Bumblebee.Utils.Nx do
iex> [first, second] = Bumblebee.Utils.Nx.batch_to_list(outputs)
iex> first.x
#Nx.Tensor<
s64[2]
s32[2]
[0, 0]
>
iex> second.x
#Nx.Tensor<
s64[2]
s32[2]
[1, 1]
>
iex> first.y
#Nx.Tensor<
s64
s32
0
>
iex> second.y
#Nx.Tensor<
s64
s32
1
>
Expand Down Expand Up @@ -122,7 +122,7 @@ defmodule Bumblebee.Utils.Nx do
iex> result = Bumblebee.Utils.Nx.composite_concatenate(left, right)
iex> result.x
#Nx.Tensor<
s64[4][2]
s32[4][2]
[
[0, 0],
[1, 1],
Expand All @@ -132,7 +132,7 @@ defmodule Bumblebee.Utils.Nx do
>
iex> result.y
#Nx.Tensor<
s64[4]
s32[4]
[0, 1, 2, 3]
>
Expand Down Expand Up @@ -164,7 +164,7 @@ defmodule Bumblebee.Utils.Nx do
iex> result = Bumblebee.Utils.Nx.composite_unflatten_batch(output, 2)
iex> result.x
#Nx.Tensor<
s64[2][1][2]
s32[2][1][2]
[
[
[0, 0]
Expand All @@ -176,7 +176,7 @@ defmodule Bumblebee.Utils.Nx do
>
iex> result.y
#Nx.Tensor<
s64[2][1]
s32[2][1]
[
[0],
[1]
Expand Down Expand Up @@ -205,12 +205,12 @@ defmodule Bumblebee.Utils.Nx do
iex> result = Bumblebee.Utils.Nx.composite_flatten_batch(output)
iex> result.x
#Nx.Tensor<
s64[4]
s32[4]
[0, 0, 1, 1]
>
iex> result.y
#Nx.Tensor<
s64[2]
s32[2]
[0, 1]
>
Expand Down Expand Up @@ -249,7 +249,7 @@ defmodule Bumblebee.Utils.Nx do
iex> idx = Nx.tensor([[1, 0], [1, 1]])
iex> Bumblebee.Utils.Nx.batched_take(t, idx)
#Nx.Tensor<
s64[2][2][2]
s32[2][2][2]
[
[
[2, 2],
Expand Down Expand Up @@ -348,7 +348,7 @@ defmodule Bumblebee.Utils.Nx do
iex> x = Nx.tensor([[1, 2], [3, 4]])
iex> Bumblebee.Utils.Nx.repeat_interleave(x, 2)
#Nx.Tensor<
s64[4][2]
s32[4][2]
[
[1, 2],
[1, 2],
Expand Down Expand Up @@ -387,7 +387,7 @@ defmodule Bumblebee.Utils.Nx do
iex> x = Nx.tensor([[1, 1], [2, 2], [3, 3], [4, 4]])
iex> Bumblebee.Utils.Nx.chunked_take(x, 2, Nx.tensor([1, 0]))
#Nx.Tensor<
s64[2][2]
s32[2][2]
[
[2, 2],
[3, 3]
Expand Down Expand Up @@ -427,7 +427,7 @@ defmodule Bumblebee.Utils.Nx do
iex> x = Nx.iota({3, 3})
iex> Bumblebee.Utils.Nx.roll(x, shifts: [1], axes: [0])
#Nx.Tensor<
s64[3][3]
s32[3][3]
[
[6, 7, 8],
[0, 1, 2],
Expand All @@ -438,7 +438,7 @@ defmodule Bumblebee.Utils.Nx do
iex> x = Nx.iota({3, 3})
iex> Bumblebee.Utils.Nx.roll(x, shifts: [-1], axes: [0])
#Nx.Tensor<
s64[3][3]
s32[3][3]
[
[3, 4, 5],
[6, 7, 8],
Expand All @@ -449,7 +449,7 @@ defmodule Bumblebee.Utils.Nx do
iex> x = Nx.iota({3, 3})
iex> Bumblebee.Utils.Nx.roll(x, shifts: [1, 2], axes: [0, 1])
#Nx.Tensor<
s64[3][3]
s32[3][3]
[
[7, 8, 6],
[1, 2, 0],
Expand Down

0 comments on commit b0dcf76

Please sign in to comment.