Skip to content

Commit

Permalink
Satisfy linter for test_model_builder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
maresb committed Aug 12, 2023
1 parent f8ab51d commit 7d6d0a0
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions tests/model_builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def fitted_model_instance(toy_X, toy_y):
"obs_error": 2,
}
model = test_ModelBuilder(
model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter"
model_config=model_config,
sampler_config=sampler_config,
test_parameter="test_paramter",
)
model.fit(toy_X)
return model
Expand Down Expand Up @@ -91,7 +93,7 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

# observed data
output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)
pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)

def _save_input_params(self, idata):
idata.attrs["test_paramter"] = json.dumps(self.test_parameter)
Expand Down Expand Up @@ -171,8 +173,10 @@ def test_empty_sampler_config_fit(toy_X, toy_y):


def test_fit(fitted_model_instance):
prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
pred = fitted_model_instance.predict(prediction_data["input"])
prediction_data = pd.DataFrame(
{"input": np.random.uniform(low=0, high=1, size=100)}
)
fitted_model_instance.predict(prediction_data["input"])
post_pred = fitted_model_instance.sample_posterior_predictive(
prediction_data["input"], extend_idata=True, combined=True
)
Expand All @@ -188,7 +192,8 @@ def test_fit_no_y(toy_X):


@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
sys.platform == "win32",
reason="Permissions for temp files not granted on windows CI.",
)
def test_predict(fitted_model_instance):
x_pred = np.random.uniform(low=0, high=1, size=100)
Expand Down

0 comments on commit 7d6d0a0

Please sign in to comment.