Skip to content

Commit

Permalink
Prototype for interface
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoskelo committed Jul 8, 2024
1 parent e697300 commit 8071f8a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
34 changes: 33 additions & 1 deletion arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays"
dag = pt.transform.materialize_with_mpms(dag)
return dag

def pack_for_uq(self,*args):
def pack_for_uq(self,*args) -> dict:
"""
Args is a list of variable names and the realized input data that needs
to be packed for a parameter study or uncertainty quantification.
Expand Down Expand Up @@ -774,6 +774,38 @@ def pack_for_uq(self,*args):

return out


def unpack(self, data):
"""
Revert data to a sequence of outputs under the assumption that a specific variable
is held constant.
::arg:: data multidimensional array tagged with dimensions that are varying.
UQAxisTag will tag each specific axis that we are going to slice.
"""

ndim = len(data.axes)

out = {}


for i in range(ndim):
axis_tags = data.axes[i].tags_of_type(UQAxisTag)
if axis_tags:
# Now we need to split this data.
for j in range(len(data.axis[i])):
the_slice = [slice(None)] * ndim
the_slice[i] = j
if i in out.keys():
out[i].append(data[the_slice])
else:
out[i] = data[the_slice]
#yield data[the_slice]


return out


# }}}


Expand Down
11 changes: 5 additions & 6 deletions examples/uncertain_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
print("========================================================")
print(b)
print("========================================================")
breakpoint()

# Eq: z = x + y
# Assumptions: x and y are independently uncertain.
Expand All @@ -37,13 +36,13 @@
actx = PytatoPyOpenCLArrayContextUQ

out = actx.pack_for_uq(actx,"x", x, x1, x2, "y", y, y1, y2)
print("===============out======================")
print("===============OUT======================")
print(out)

breakpoint()



x = out["x"]
y = out["y"]

breakpoint()

x + y

0 comments on commit 8071f8a

Please sign in to comment.