Skip to content

Commit

Permalink
Update to new ert api for no refcase
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Aug 15, 2023
1 parent cda09c4 commit 0d10b77
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 33 deletions.
2 changes: 1 addition & 1 deletion semeio/communication/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def publish(self, namespace, data):

all_data.append(data)
with open(output_file, "w", encoding="utf-8") as f_handle:
json.dump(all_data, f_handle)
json.dump(all_data, f_handle, default=str)

def publish_msg(self, namespace, msg):
output_file = self._prepare_output_file(namespace)
Expand Down
20 changes: 17 additions & 3 deletions semeio/workflows/correlated_observations_scaling/job_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from datetime import datetime

import configsuite
from configsuite import MetaKeys as MK
Expand All @@ -21,10 +22,12 @@ def _min_length(value):

@configsuite.validator_msg("Minimum value of index must be >= 0")
def _min_value(value):
return value >= 0
if isinstance(value, int):
return value >= 0
return True


_NUM_CONVERT_MSG = "Will go through the input and try to convert to list of int"
_NUM_CONVERT_MSG = "Will go through the input and try to convert to list"


@configsuite.transformation_msg(_NUM_CONVERT_MSG)
Expand All @@ -33,6 +36,8 @@ def _to_int_list(value):
if isinstance(value, int):
return [value]
if isinstance(value, (list, tuple)):
if all(isinstance(val, datetime) for val in value):
return value
value = ",".join([str(x) for x in value])
return _realize_list(value)

Expand Down Expand Up @@ -146,6 +151,15 @@ def _CALCULATE_KEYS_key_not_empty_list(content):
on "FOPR", but only update the scaling on indices "50-100".
"""


@configsuite.validator_msg("int or datetime")
def _is_int_or_datetime(value):
    """Validate that *value* is either an ``int`` or a ``datetime``.

    Used as the predicate behind the ``IntOrDate`` basic type so index
    entries may be given as integer steps or as observation dates.
    """
    return isinstance(value, datetime) or isinstance(value, int)


# Custom configsuite basic type accepting either representation of an index.
IntOrDate = configsuite.BasicType("int_or_datetime", _is_int_or_datetime)


_KEYS_SCHEMA = {
MK.ElementValidators: (_CALCULATE_KEYS_key_not_empty_list,),
MK.Type: types.List,
Expand All @@ -167,7 +181,7 @@ def _CALCULATE_KEYS_key_not_empty_list(content):
"(1,2,4-6,14-15) ->[1, 2, 4, 5, 6, 14, 15]",
MK.Content: {
MK.Item: {
MK.Type: types.Integer,
MK.Type: IntOrDate,
MK.ElementValidators: (_min_value,),
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,11 @@ def _update_scaling(obs, scale_factor, obs_list):
for event in obs_list:
obs_vector = obs[event.key]
index_list = (
event.index
if event.index
else [x - 1 for x in obs_vector.observations.keys()]
event.index if event.index else list(obs_vector.observations.keys())
)
for step, obs_node in obs_vector.observations.items():
if obs_vector.observation_type.name == "SUMMARY_OBS":
if step - 1 in index_list:
if step in index_list:
obs_node.std_scaling = scale_factor
else:
obs_node.std_scaling[event.active_list] = scale_factor
Expand Down
2 changes: 1 addition & 1 deletion semeio/workflows/spearman_correlation_job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def spearman_job(
zip(
clusters,
columns.get_level_values(0),
columns.get_level_values("data_index"),
columns.get_level_values("key_index"),
)
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
],
install_requires=[
"ecl",
"ert>=5.1.0-b8",
"ert>=6.0.0-rc0",
"configsuite>=0.6",
"numpy",
"pandas>1.3.0",
Expand Down
11 changes: 0 additions & 11 deletions tests/communication/unit/test_file_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,6 @@ def test_file_reporter_publish_valid_json(data, tmpdir):
assert loaded_data == [data]


def test_file_reporter_publish_invalid_json(tmpdir):
    """Publishing a non-JSON-serializable object must raise ``TypeError``."""
    tmpdir.chdir()
    reporter = FileReporter(os.getcwd())

    # The json module object itself is not JSON serializable...
    payload = json

    with pytest.raises(TypeError):
        reporter.publish("data", payload)


def test_file_reporter_publish_multiple_json(tmpdir):
tmpdir.chdir()
namespace = "some_data"
Expand Down
30 changes: 28 additions & 2 deletions tests/jobs/correlated_observations_scaling/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from datetime import datetime
from unittest.mock import MagicMock

import numpy as np
Expand Down Expand Up @@ -144,7 +145,19 @@ def test_main_entry_point_history_data_calc(snake_oil_facade, config, expected_r


def test_main_entry_point_history_data_calc_subset(snake_oil_facade):
config = {"CALCULATE_KEYS": {"keys": [{"key": "FOPR", "index": [10, 20]}]}}
config = {
"CALCULATE_KEYS": {
"keys": [
{
"key": "FOPR",
"index": [
datetime(2010, 4, 20),
datetime(2010, 7, 29),
],
}
]
}
}
obs = snake_oil_facade.get_observations()
obs_vector = obs["FOPR"]

Expand Down Expand Up @@ -194,7 +207,20 @@ def test_main_entry_point_sum_data_update(snake_oil_facade, monkeypatch):
def test_main_entry_point_shielded_data(monkeypatch):
ert = LibresFacade.from_config_file("snake_oil.ert")
cos_config = {
"CALCULATE_KEYS": {"keys": [{"key": "FOPR", "index": [1, 2, 3, 4, 5]}]}
"CALCULATE_KEYS": {
"keys": [
{
"key": "FOPR",
"index": [
datetime(2010, 1, 20),
datetime(2010, 1, 30),
datetime(2010, 2, 9),
datetime(2010, 2, 19),
datetime(2010, 3, 1),
],
}
]
}
}

obs = ert.get_observations()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_valid_config_setup(valid_config):
},
[
(
"'Will go through the input and try to convert to list of int' "
"'Will go through the input and try to convert to list' "
"failed on input '[-1, 2, 3]' with error 'Elements can not be "
"negative, neither singletons nor in range, got: -1'"
),
Expand Down
21 changes: 13 additions & 8 deletions tests/jobs/test_scale_observations/test_scale_observations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pytest

from semeio.workflows.correlated_observations_scaling.update_scaling import (
Expand All @@ -16,25 +18,28 @@ def fixture_snake_oil_obs(snake_oil_facade):
return snake_oil_facade.get_observations()


@pytest.mark.parametrize("index_list", [None, [0, 1, 2, 3]])
@pytest.mark.parametrize(
"index_list",
[None, [datetime(2010, 1, 10), datetime(2010, 1, 30), datetime(2010, 2, 9, 0, 0)]],
)
def test_scale_history_summary_obs(snake_oil_obs, index_list):
scale_observations(snake_oil_obs, 1.2345, [Config("FOPR", index_list)])

obs_vector = snake_oil_obs["FOPR"]
for index, node in enumerate(obs_vector):
if not index_list or index in index_list:
assert node.std_scaling == 1.2345, f"index: {index}"
for date, node in obs_vector.observations.items():
if not index_list or date in index_list:
assert node.std_scaling == 1.2345, f"index: {date}"
else:
assert node.std_scaling == 1.0, f"index: {index}"
assert node.std_scaling == 1.0, f"index: {date}"


@pytest.mark.parametrize("index_list", [None, [35]])
@pytest.mark.parametrize("index_list", [None, [datetime(2010, 12, 26)]])
def test_scale_summary_obs(snake_oil_obs, index_list):
scale_observations(snake_oil_obs, 1.2345, [Config("WOPR_OP1_36", index_list)])

obs_vector = snake_oil_obs["WOPR_OP1_36"]
node = obs_vector.observations[36]
assert node.std_scaling == 1.2345, f"index: {36}"
node = obs_vector.observations[datetime(2010, 12, 26)]
assert node.std_scaling == 1.2345, f"index: {datetime(2010, 12, 26)}"


@pytest.mark.parametrize("index_list", [None, [400, 800]])
Expand Down
2 changes: 1 addition & 1 deletion tests/legacy_test_data/poly_normal/poly-ies.ert
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ NUM_REALIZATIONS 100
MIN_REALIZATIONS 1

GEN_KW COEFFS coeff.tmpl coeffs.json coeff_priors
GEN_DATA POLY_RES RESULT_FILE:poly_%d.out REPORT_STEPS:0 INPUT_FORMAT:ASCII
GEN_DATA POLY_RES RESULT_FILE:poly.out

INSTALL_JOB poly_eval POLY_EVAL
SIMULATION_JOB poly_eval
Expand Down

0 comments on commit 0d10b77

Please sign in to comment.