From a5eb20a319c79a700486ce78d7b178cdf24f78f1 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 23 Apr 2020 11:41:56 -0500 Subject: [PATCH] port Subset to cuda --- pyop2/gpu/cuda.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/pyop2/gpu/cuda.py b/pyop2/gpu/cuda.py index 93499bbab..32fbbe8a8 100644 --- a/pyop2/gpu/cuda.py +++ b/pyop2/gpu/cuda.py @@ -100,6 +100,17 @@ def _kernel_args_(self): return (m_gpu,) +class Subset(Subset): + """ + ExtrudedSet for GPU. + """ + @cached_property + def _kernel_args_(self): + m_gpu = cuda.mem_alloc(int(self._indices.nbytes)) + cuda.memcpy_htod(m_gpu, self._indices) + return self._superset._kernel_args_ + (m_gpu, ) + + class Dat(petsc_Dat): """ Dat for GPU. @@ -318,8 +329,11 @@ def code_to_compile(self): builder.set_kernel(self._kernel) wrapper = generate(builder) + print('Compiling...', wrapper.name) code, processed_program, args_to_make_global = generate_gpu_kernel(wrapper, self.args, self.argshapes) + + print(code) for i, arg_to_make_global in enumerate(args_to_make_global): numpy.save(self.ith_added_global_arg_i(i), arg_to_make_global) @@ -386,6 +400,7 @@ def argshapes(self): argshapes = ((), ()) if self._iterset._argtypes_: # TODO: verify that this bogus value doesn't affect anyone. + # raise NotImplementedError() argshapes += ((), ) for arg in self._args: @@ -605,7 +620,7 @@ def insn_needs_atomic(insn): raise ValueError("gpu_strategy can be 'scpt'," " 'user_specified_tile' or 'auto_tile'.") elif program.name in ["wrap_zero", "wrap_expression_kernel", - "wrap_pyop2_kernel_uniform_extrusion", + "wrap_expression", "wrap_pyop2_kernel_uniform_extrusion", "wrap_form_cell_integral_otherwise", ]: from pyop2.gpu.snpt import snpt_transform