From 8071f8aaae9edc196da3b75e437bd319f5fa9a2b Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 8 Jul 2024 09:22:12 -0500 Subject: [PATCH] Prototype for interface --- arraycontext/impl/pytato/__init__.py | 34 +++++++++++++++++++++++++++- examples/uncertain_prop.py | 11 ++++----- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 58201b6c..09c3def0 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -720,7 +720,7 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays" dag = pt.transform.materialize_with_mpms(dag) return dag - def pack_for_uq(self,*args): + def pack_for_uq(self,*args) -> dict: """ Args is a list of variable names and the realized input data that needs to be packed for a parameter study or uncertainty quantification. @@ -774,6 +774,38 @@ def pack_for_uq(self,*args): return out + + def unpack(self, data): + """ + Revert data to a sequence of outputs under the assumption that a specific variable + is held constant. + + ::arg:: data multidimensional array tagged with dimensions that are varying. + UQAxisTag will tag each specific axis that we are going to slice. + """ + + ndim = len(data.axes) + + out = {} + + + for i in range(ndim): + axis_tags = data.axes[i].tags_of_type(UQAxisTag) + if axis_tags: + # Now we need to split this data. + for j in range(len(data.axis[i])): + the_slice = [slice(None)] * ndim + the_slice[i] = j + if i in out.keys(): + out[i].append(data[the_slice]) + else: + out[i] = data[the_slice] + #yield data[the_slice] + + + return out + + # }}} diff --git a/examples/uncertain_prop.py b/examples/uncertain_prop.py index 35fa8e2e..67afeaea 100644 --- a/examples/uncertain_prop.py +++ b/examples/uncertain_prop.py @@ -19,7 +19,6 @@ print("========================================================") print(b) print("========================================================") -breakpoint() # Eq: z = x + y # Assumptions: x and y are independently uncertain. @@ -37,13 +36,13 @@ actx = PytatoPyOpenCLArrayContextUQ out = actx.pack_for_uq(actx,"x", x, x1, x2, "y", y, y1, y2) -print("===============out======================") +print("===============OUT======================") print(out) -breakpoint() - - - +x = out["x"] +y = out["y"] +breakpoint() +x + y