Skip to content

Commit

Permalink
attempt to improve compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed Sep 4, 2024
1 parent e275fd3 commit 61e3dfa
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ def __init__(self,
if wait_event_queue_length is None:
wait_event_queue_length = 10

self._force_device_scalars = True
# Subclasses might still be using the old
# "force_devices_scalars: bool = False" interface, in which case we need
# to explicitly pass force_device_scalars=True in clone()
self._passed_force_device_scalars = force_device_scalars is not None

self._wait_event_queue_length = wait_event_queue_length
self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {}

Expand Down Expand Up @@ -260,8 +266,13 @@ def call_loopy(self, t_unit, **kwargs):
return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()}

def clone(self):
return type(self)(self.queue, self.allocator,
wait_event_queue_length=self._wait_event_queue_length)
if self._passed_force_device_scalars:
return type(self)(self.queue, self.allocator,
wait_event_queue_length=self._wait_event_queue_length,
force_device_scalars=True)
else:
return type(self)(self.queue, self.allocator,
wait_event_queue_length=self._wait_event_queue_length)

# }}}

Expand Down

0 comments on commit 61e3dfa

Please sign in to comment.