Skip to content

Commit

Permalink
fix some warnings in tests (#483)
Browse files Browse the repository at this point in the history
* fix some warnings in tests

* add comment+complex/real arange test
  • Loading branch information
matthiasdiener authored Feb 6, 2024
1 parent c3fe047 commit ae41332
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 8 deletions.
4 changes: 3 additions & 1 deletion pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2231,7 +2231,9 @@ def arange(*args: Any, **kwargs: Any) -> Array:
stop = dtype.type(inf.stop)

from math import ceil
size = max(0, int(ceil((stop-start)/step)))
# np.real() suppresses "ComplexWarning: Casting complex values to real
# discards the imaginary part":
size = max(0, int(ceil((np.real(stop)-np.real(start))/np.real(step))))

from pymbolic.primitives import Variable
return IndexLambda(expr=start + Variable("_0") * step,
Expand Down
55 changes: 48 additions & 7 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_data_wrapper(ctx_factory):
# With name/shape
x_in = np.array([[1, 2], [3, 4], [5, 6]])
n = pt.make_size_param("n")
x = pt.make_data_wrapper(x_in, name="x", shape=(n, 2))
x = pt.make_data_wrapper(x_in, shape=(n, 2))
prog = pt.generate_loopy(x)
_, (x_out,) = prog(queue)
assert (x_out == x_in).all()
Expand Down Expand Up @@ -302,7 +302,8 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse):
if dtype in "FDG" and not_valid_in_complex:
continue
y = pt.make_data_wrapper(
y_orig.astype(dtype), name=f"y{dtype}")
y_orig.astype(dtype),
tags=frozenset([pt.tags.PrefixNamed(f"y{dtype}")]))
exprs[dtype] = pt_op(x_in, y)

prog = pt.generate_loopy(exprs)
Expand Down Expand Up @@ -352,14 +353,15 @@ def test_array_array_binary_arith(ctx_factory, which, reverse):
continue

x_in = x_orig.astype(first_dtype)
x = pt.make_data_wrapper(x_in, name="x")
x = pt.make_data_wrapper(x_in)

exprs = {}
for dtype in ARITH_DTYPES:
if dtype in "FDG" and not_valid_in_complex:
continue
y = pt.make_data_wrapper(
y_orig.astype(dtype), name=f"y{dtype}")
y_orig.astype(dtype),
tags=frozenset([pt.tags.PrefixNamed(f"y{dtype}")]))
exprs[dtype] = pt_op(x, y)

prog = pt.generate_loopy(exprs)
Expand Down Expand Up @@ -649,10 +651,10 @@ def test_full_zeros_ones(ctx_factory, dtype):
assert (t == 2).all()


def test_passsing_bound_arguments_raises(ctx_factory):
def test_passing_bound_arguments_raises(ctx_factory):
queue = cl.CommandQueue(ctx_factory())

x = pt.make_data_wrapper(np.ones(10), name="x")
x = pt.make_data_wrapper(np.ones(10), tags=frozenset([pt.tags.PrefixNamed("x")]))
prg = pt.generate_loopy(42*x)

with pytest.raises(ValueError):
Expand Down Expand Up @@ -1036,6 +1038,8 @@ def test_arange(ctx_factory):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

# {{{ Integer

from numpy.random import default_rng
rng = default_rng(seed=0)
for _ in range(30):
Expand All @@ -1051,6 +1055,43 @@ def test_arange(ctx_factory):

assert np.array_equal(pt_res, np_res)

# }}}

# {{{ Real

# generates '[0. ... 4.]':
np_res = np.arange(5, dtype=np.float64)
pt_res_sym = pt.arange(5, dtype=np.float64)

_, (pt_res,) = pt.generate_loopy(pt_res_sym)(cq)
print(np_res, pt_res)

assert np.array_equal(pt_res, np_res)

# }}}

# {{{ Complex

# generates '[]':
np_res = np.arange(5j, dtype=np.complex128)
pt_res_sym = pt.arange(5j, dtype=np.complex128)

_, (pt_res,) = pt.generate_loopy(pt_res_sym)(cq)
print(np_res, pt_res)

assert np.array_equal(pt_res, np_res)

# generates '[0.+0.j ... 4.+0.j]':
np_res = np.arange(5, dtype=np.complex128)
pt_res_sym = pt.arange(5, dtype=np.complex128)

_, (pt_res,) = pt.generate_loopy(pt_res_sym)(cq)
print(np_res, pt_res)

assert np.array_equal(pt_res, np_res)

# }}}


@pytest.mark.parametrize("which,num_args", ([("maximum", 2),
("minimum", 2),
Expand Down Expand Up @@ -1544,7 +1585,7 @@ def get_np_input_args():
pt_dag = kernel(pt, **{kw: pt.make_data_wrapper(arg)
for kw, arg in np_inputs.items()})

knl = pt.generate_loopy(pt_dag, options=lp.Options(write_cl=True))
knl = pt.generate_loopy(pt_dag, options=lp.Options(write_code=True))

_, (pt_result,) = knl(cq)

Expand Down

0 comments on commit ae41332

Please sign in to comment.