diff --git a/pennylane/data/base/dataset.py b/pennylane/data/base/dataset.py index 3d80b675bed..dedb43782df 100644 --- a/pennylane/data/base/dataset.py +++ b/pennylane/data/base/dataset.py @@ -250,6 +250,7 @@ def identifiers(self) -> typing.Mapping[str, str]: # pylint: disable=function-r return { attr_name: getattr(self, attr_name) for attr_name in self.info.get("identifiers", self.info.get("params", [])) + if attr_name in self.bind } @property @@ -323,7 +324,7 @@ def write( values are "w-" (create, fail if file exists), "w" (create, overwrite existing), and "a" (append existing, create if doesn't exist). Default is "w-". attributes: Optional list of attributes to copy. If None, all attributes - will be copied. + will be copied. Note that identifiers will always be copied. overwrite: Whether to overwrite attributes that already exist in this dataset. """ @@ -337,6 +338,12 @@ def write( hdf5.copy_all(self.bind, dest.bind, *attributes, on_conflict=on_conflict) + missing_identifiers = [ + identifier for identifier in self.identifiers if not hasattr(dest, identifier) + ] + if missing_identifiers: + hdf5.copy_all(self.bind, dest.bind, *missing_identifiers) + def _init_bind( self, data_name: Optional[str] = None, identifiers: Optional[Tuple[str, ...]] = None ): diff --git a/pennylane/data/data_manager/__init__.py b/pennylane/data/data_manager/__init__.py index 764f61d6804..1c19ef052e7 100644 --- a/pennylane/data/data_manager/__init__.py +++ b/pennylane/data/data_manager/__init__.py @@ -97,6 +97,23 @@ def _download_dataset( f.write(resp.content) +def _validate_attributes(data_struct: dict, data_name: str, attributes: typing.Iterable[str]): + """Checks that ``attributes`` contains only valid attributes for the given + ``data_name``. If any attributes do not exist, raise a ValueError.""" + invalid_attributes = [ + attr for attr in attributes if attr not in data_struct[data_name]["attributes"] + ] + if not invalid_attributes: + return + + if len(invalid_attributes) == 1: + values_err = f"'{invalid_attributes[0]}' is an invalid attribute for '{data_name}'" + else: + values_err = f"{invalid_attributes} are invalid attributes for '{data_name}'" + + raise ValueError(f"{values_err}. Valid attributes are: {data_struct[data_name]['attributes']}") + + def load( # pylint: disable=too-many-arguments data_name: str, attributes: Optional[typing.Iterable[str]] = None, @@ -186,14 +203,18 @@ def load( # pylint: disable=too-many-arguments >>> print(circuit()) -1.0791430411076344 """ + foldermap = _get_foldermap() + data_struct = _get_data_struct() + params = format_params(**params) + if attributes: + _validate_attributes(data_struct, data_name, attributes) + folder_path = Path(folder_path) if cache_dir and not Path(cache_dir).is_absolute(): cache_dir = folder_path / cache_dir - foldermap = _get_foldermap() - data_paths = [data_path for _, data_path in foldermap.find(data_name, **params)] dest_paths = [folder_path / data_path for data_path in data_paths] @@ -374,7 +395,9 @@ def load_interactive(): value = _interactive_request_single(node, param) description[param] = value - attributes = _interactive_request_attributes(data_struct[data_name]["attributes"]) + attributes = _interactive_request_attributes( + [attribute for attribute in data_struct[data_name]["attributes"] if attribute not in params] + ) force = input("Force download files? (Default is no) [y/N]: ") in ["y", "Y"] dest_folder = Path( input("Folder to download to? (Default is pwd, will download to /datasets subdirectory): ") @@ -390,6 +413,7 @@ def load_interactive(): if approve not in ["Y", "", "y"]: print("Aborting and not downloading!") return None + return load( data_name, attributes=attributes, folder_path=dest_folder, force=force, **description )[0] diff --git a/tests/data/data_manager/test_dataset_access.py b/tests/data/data_manager/test_dataset_access.py index 830ceddd906..5c797a63ab8 100644 --- a/tests/data/data_manager/test_dataset_access.py +++ b/tests/data/data_manager/test_dataset_access.py @@ -24,7 +24,7 @@ import pennylane as qml import pennylane.data.data_manager from pennylane.data import Dataset -from pennylane.data.data_manager import DataPath, S3_URL +from pennylane.data.data_manager import DataPath, S3_URL, _validate_attributes # pylint:disable=protected-access,redefined-outer-name @@ -376,3 +376,26 @@ def test_download_dataset_escapes_url_partial(mock_download_partial, datapath, e mock_download_partial.assert_called_once_with( f"{S3_URL}/{escaped}", dest, attributes, overwrite=force ) + + +@pytest.mark.parametrize( + "attributes,msg", + [ + ( + ["x", "y", "z", "foo"], + r"'foo' is an invalid attribute for 'my_dataset'. Valid attributes are: \['x', 'y', 'z'\]", + ), + ( + ["x", "y", "z", "foo", "bar"], + r"\['foo', 'bar'\] are invalid attributes for 'my_dataset'. Valid attributes are: \['x', 'y', 'z'\]", + ), + ], +) +def test_validate_attributes_except(attributes, msg): + """Test that ``_validate_attributes()`` raises a ValueError when passed + invalid attributes.""" + + data_struct = {"my_dataset": {"attributes": ["x", "y", "z"]}} + + with pytest.raises(ValueError, match=msg): + _validate_attributes(data_struct, "my_dataset", attributes) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 428c8db1e76..bb0f7a11584 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -198,12 +198,25 @@ def test_identifiers_base(self, identifiers, expect): assert ds.identifiers == expect + def test_identifiers_base_missing(self): + """Test that identifiers whose attribute is missing on the + dataset will not be in the returned dict.""" + ds = Dataset(x="1", identifiers=("x", "y")) + + assert ds.identifiers == {"x": "1"} + def test_subclass_identifiers(self): """Test that dataset subclasses' identifiers can be set.""" ds = MyDataset(x="1", y="2", description="abc") assert ds.identifiers == {"x": "1", "y": "2"} + def test_subclass_identifiers_missing(self): + """Test that dataset subclasses' identifiers can be set.""" + ds = MyDataset(x="1", description="abc") + + assert ds.identifiers == {"x": "1"} + def test_attribute_info(self): """Test that attribute info can be set and accessed on a dataset attribute.""" @@ -357,6 +370,23 @@ def test_write(self, tmp_path, mode): assert ds_2.bind is not ds.bind assert ds.attrs == ds_2.attrs + @pytest.mark.parametrize( + "attributes_arg,attributes_expect", + [ + (["x"], ["x", "y"]), + (["x", "y", "data"], ["x", "y", "data"]), + (["data"], ["x", "y", "data"]), + ], + ) + def test_write_partial_always_copies_identifiers(self, attributes_arg, attributes_expect): + """Test that ``write`` will always copy attributes that are identifiers.""" + ds = Dataset(x="a", y="b", data="Some data", identifiers=("x", "y")) + ds_2 = Dataset() + + ds.write(ds_2, attributes=attributes_arg) + assert set(ds_2.list_attributes()) == set(attributes_expect) + assert ds_2.identifiers == ds.identifiers + def test_init_subclass(self): """Test that __init_subclass__() does the following: