ModelBuilder.load versatility improvements #210

Merged · 4 commits · Jul 11, 2023
4 changes: 4 additions & 0 deletions pymc_experimental/linearmodel.py
@@ -37,6 +37,10 @@ def default_sampler_config(self):
"target_accept": 0.95,
}

@property
def _serializable_model_config(self) -> Dict:
return self.model_config

@property
def output_var(self):
return "y_hat"
22 changes: 20 additions & 2 deletions pymc_experimental/model_builder.py
@@ -297,7 +297,7 @@ def sample_model(self, **kwargs):
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_posterior_predictive(idata))

self.set_idata_attrs(idata)
idata = self.set_idata_attrs(idata)
return idata

def set_idata_attrs(self, idata=None):
@@ -338,6 +338,10 @@ def set_idata_attrs(self, idata=None):
idata.attrs["version"] = self.version
idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
# Only classes with non-dataset parameters will implement _save_input_params
if hasattr(self, "_save_input_params"):
self._save_input_params(idata)
return idata
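
The hook above is optional. As a minimal illustrative sketch (not part of this PR), a subclass that carries a plain, non-dataset constructor argument could persist it like this; the other required ModelBuilder overrides are omitted for brevity and the label parameter is made up:

import json

from pymc_experimental.model_builder import ModelBuilder


class LabelledModel(ModelBuilder):
    # Illustrative fragment only: build_model, output_var, and the remaining
    # required overrides are omitted for brevity.
    def __init__(self, model_config=None, sampler_config=None, label=None):
        self.label = label  # hypothetical non-dataset parameter
        super().__init__(model_config=model_config, sampler_config=sampler_config)

    def _save_input_params(self, idata):
        # Called from set_idata_attrs(); stored as JSON so it survives the netCDF round trip.
        idata.attrs["label"] = json.dumps(self.label)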

def save(self, fname: str) -> None:
"""
@@ -375,6 +379,17 @@ def save(self, fname: str) -> None:
else:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")

@classmethod
def _convert_dims_to_tuple(cls, model_config: Dict) -> Dict:
for key in model_config:
if (
isinstance(model_config[key], dict)
and "dims" in model_config[key]
and isinstance(model_config[key]["dims"], list)
):
model_config[key]["dims"] = tuple(model_config[key]["dims"])
return model_config
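
A short illustration of the round trip this helper guards against; the values are made up, and the json behaviour shown is standard library behaviour:

import json

from pymc_experimental.model_builder import ModelBuilder

model_config = {"a": {"loc": 0, "scale": 10, "dims": ("numbers",)}}
restored = json.loads(json.dumps(model_config))
print(restored["a"]["dims"])  # ['numbers'] -- JSON has no tuple type, so tuples come back as lists
print(ModelBuilder._convert_dims_to_tuple(restored)["a"]["dims"])  # ('numbers',)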

@classmethod
def load(cls, fname: str):
"""
@@ -403,8 +418,10 @@ def load(cls, fname: str):
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
# dims need to be converted back to tuples, because JSON serialization turns tuples into lists
model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"]))
model = cls(
model_config=json.loads(idata.attrs["model_config"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata
@@ -480,6 +497,7 @@ def fit(
combined_data = pd.concat([X_df, y], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore

return self.idata # type: ignore

def predict(
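
Taken together, these changes support an end-to-end save/load round trip. A rough usage sketch with made-up data follows; the LinearModel import path, fit signature, and file name are assumptions based on the rest of the repository:

import numpy as np
import pandas as pd

from pymc_experimental.linearmodel import LinearModel

X = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
y = pd.Series(2 * X["input"] + np.random.normal(scale=0.1, size=100), name="y")

model = LinearModel()
model.fit(X, y)
model.save("linear_model.nc")  # hypothetical file name

restored = LinearModel.load("linear_model.nc")  # model_config round-trips through JSON, dims restored as tuples
preds = restored.predict(X)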
36 changes: 18 additions & 18 deletions pymc_experimental/tests/test_linearmodel.py
@@ -64,6 +64,24 @@ def fitted_linear_model_instance(toy_X, toy_y):
return model


@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_linear_model_instance):
model = fitted_linear_model_instance
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
model.save(temp.name)
model2 = LinearModel.load(temp.name)
assert model.idata.groups() == model2.idata.groups()

X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
pred1 = model.predict(X_pred, random_seed=423)
pred2 = model2.predict(X_pred, random_seed=423)
# Predictions should be identical
np.testing.assert_array_equal(pred1, pred2)
temp.close()


def test_save_without_fit_raises_runtime_error(toy_X, toy_y):
test_model = LinearModel()
with pytest.raises(RuntimeError):
@@ -83,24 +101,6 @@ def test_fit(fitted_linear_model_instance):
assert isinstance(post_pred, xr.DataArray)


@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_linear_model_instance):
model = fitted_linear_model_instance
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
model.save(temp.name)
model2 = LinearModel.load(temp.name)
assert model.idata.groups() == model2.idata.groups()

X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
pred1 = model.predict(X_pred, random_seed=423)
pred2 = model2.predict(X_pred, random_seed=423)
# Predictions should be identical
np.testing.assert_array_equal(pred1, pred2)
temp.close()


def test_predict(fitted_linear_model_instance):
model = fitted_linear_model_instance
X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
72 changes: 50 additions & 22 deletions pymc_experimental/tests/test_model_builder.py
@@ -13,6 +13,7 @@
# limitations under the License.

import hashlib
import json
import sys
import tempfile
from typing import Dict
@@ -43,29 +44,35 @@ def toy_y(toy_X):
@pytest.fixture(scope="module")
def fitted_model_instance(toy_X, toy_y):
sampler_config = {
"draws": 500,
"tune": 300,
"draws": 100,
"tune": 100,
"chains": 2,
"target_accept": 0.95,
}
model_config = {
"a": {"loc": 0, "scale": 10},
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}
model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
model = test_ModelBuilder(
model_config=model_config, sampler_config=sampler_config, test_parameter="test_parameter"
)
model.fit(toy_X)
return model


class test_ModelBuilder(ModelBuilder):
def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
self.test_parameter = test_parameter
super().__init__(model_config=model_config, sampler_config=sampler_config)

_model_type = "LinearModel"
_model_type = "test_model"
version = "0.1"

def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
coords = {"numbers": np.arange(len(X))}
self.generate_and_preprocess_model_data(X, y)
with pm.Model() as self.model:
with pm.Model(coords=coords) as self.model:
if model_config is None:
model_config = self.default_model_config
x = pm.MutableData("x", self.X["input"].values)
@@ -79,13 +86,16 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
obs_error = model_config["obs_error"]

# priors
a = pm.Normal("a", a_loc, sigma=a_scale)
a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

# observed data
output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)

def _save_input_params(self, idata):
idata.attrs["test_paramter"] = json.dumps(self.test_parameter)

@property
def output_var(self):
return "output"
@@ -107,7 +117,7 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
@property
def default_model_config(self) -> Dict:
return {
"a": {"loc": 0, "scale": 10},
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}
@@ -122,6 +132,38 @@ def default_sampler_config(self) -> Dict:
}


def test_save_input_params(fitted_model_instance):
assert fitted_model_instance.idata.attrs["test_parameter"] == '"test_parameter"'


def test_save_load(fitted_model_instance):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
fitted_model_instance.save(temp.name)
test_builder2 = test_ModelBuilder.load(temp.name)
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
assert fitted_model_instance.id == test_builder2.id
x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
assert pred1.shape == pred2.shape
temp.close()


def test_convert_dims_to_tuple(fitted_model_instance):
model_config = {
"a": {
"loc": 0,
"scale": 10,
"dims": [
"x",
],
},
}
converted_model_config = fitted_model_instance._convert_dims_to_tuple(model_config)
assert converted_model_config["a"]["dims"] == ("x",)


def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
if check_idata:
assert fitted_model_instance.idata is not None
@@ -162,20 +204,6 @@ def test_fit_no_y(toy_X):
@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_model_instance):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
fitted_model_instance.save(temp.name)
test_builder2 = test_ModelBuilder.load(temp.name)
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()

x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
assert pred1.shape == pred2.shape
temp.close()


def test_predict(fitted_model_instance):
x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})