From 518c18bc52d544dcacc89a66dc99e5589de5f117 Mon Sep 17 00:00:00 2001 From: Sudarshan Vijay <47235474+sudarshanv01@users.noreply.github.com> Date: Thu, 3 Aug 2023 02:33:49 -0700 Subject: [PATCH] Feat: Add to_csv and to_frame (#103) * Convert a Series to a frame through the to_frame method * Implement to_frame for sequence of series * Alter to_frame to allow it to take in multiple y-inputs * Add width to series specifically for the to_frame and to_csv methods * Allow different length series to to_frame * Add to_csv and to_frame to the Mixin class * Added documentation for the to_frame and to_csv methods * Assert in to_frame that the dimension of weight and y is the same --- src/py4vasp/_third_party/graph/graph.py | 56 ++++++++++++++++++++ src/py4vasp/_third_party/graph/mixin.py | 50 +++++++++++++++++- src/py4vasp/_third_party/graph/series.py | 4 +- tests/data/conftest.py | 2 + tests/third_party/graph/test_graph.py | 67 ++++++++++++++++++++++++ tests/third_party/graph/test_mixin.py | 44 ++++++++++++++++ 6 files changed, 221 insertions(+), 2 deletions(-) diff --git a/src/py4vasp/_third_party/graph/graph.py b/src/py4vasp/_third_party/graph/graph.py index 5c3f7bc8..6824658e 100644 --- a/src/py4vasp/_third_party/graph/graph.py +++ b/src/py4vasp/_third_party/graph/graph.py @@ -1,6 +1,7 @@ # Copyright © VASP Software GmbH, # Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import itertools +import uuid from collections.abc import Sequence from dataclasses import dataclass, fields, replace @@ -13,6 +14,7 @@ go = import_.optional("plotly.graph_objects") subplots = import_.optional("plotly.subplots") +pd = import_.optional("pandas") @dataclass @@ -155,6 +157,60 @@ def _set_yaxis_options(self, figure): if self.y2label: figure.layout.yaxis2.title.text = self.y2label + def to_frame(self): + """Convert graph to a pandas dataframe. + + Every series will have at least two columns, named after the series name + with the suffix x and y. Additionally, if weights are provided, they will + also be written out as another column. If a series does not have a name, a + name will be generated based on a uuid. + + Returns + ------- + Dataframe + A pandas dataframe with columns for each series in the graph + """ + df = pd.DataFrame() + for series in np.atleast_1d(self.series): + _df = self._create_and_populate_df(series) + df = df.join(_df, how="outer") + return df + + def to_csv(self, filename): + """Export graph to a csv file. + + Starting from the dataframe generated from `to_frame`, use the `to_csv` method + implemented in pandas to write out a csv file with a given filename + + Parameters + ---------- + filename: str | Path + Name of the exported csv file + """ + df = self.to_frame() + df.to_csv(filename, index=False) + + def _create_and_populate_df(self, series): + df = pd.DataFrame() + df[self._name_column(series, "x", None)] = series.x + for idx, series_y in enumerate(np.atleast_2d(series.y)): + df[self._name_column(series, "y", idx)] = series_y + if series.width is not None: + assert series.width.ndim == series.y.ndim + for idx, series_width in enumerate(np.atleast_2d(series.width)): + df[self._name_column(series, "width", idx)] = series_width + return df + + def _name_column(self, series, suffix, idx=None): + if series.name: + text_suffix = series.name.replace(" ", "_") + f".{suffix}" + else: + text_suffix = "series_" + str(uuid.uuid1()) + if series.y.ndim == 1 or idx is None: + return text_suffix + else: + return f"{text_suffix}{idx}" + @property def _subplot_on(self): return any(series.subplot for series in self) diff --git a/src/py4vasp/_third_party/graph/mixin.py b/src/py4vasp/_third_party/graph/mixin.py index 8bf93fa4..59c018f6 100644 --- a/src/py4vasp/_third_party/graph/mixin.py +++ b/src/py4vasp/_third_party/graph/mixin.py @@ -1,6 +1,7 @@ # Copyright © VASP Software GmbH, # Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import abc +import os from py4vasp._third_party.graph.graph import Graph from py4vasp._util import convert @@ -15,6 +16,49 @@ class Mixin(abc.ABC): def to_graph(self, *args, **kwargs): pass + def to_frame(self, *args, **kwargs): + """Wrapper around the :py:meth:`to_frame` function. + + Generates dataframes from the graph object. For information about + parameters that can be passed to this method, look at :py:meth:`to_graph`. + + Returns + ------- + Dataframe + Pandas dataframe corresponding to data in the graph + """ + graph = self.to_graph(*args, **kwargs) + return graph.to_frame() + + def to_csv(self, *args, filename=None, **kwargs): + """Converts data to a csv file. + + Writes out a csv file for data stored in a dataframe generated with + the :py:meth:`to_frame` method. Useful for creating external plots + for further analysis. + + If no filename is provided a default filename is deduced from the + name of the class. + + Note that the filename must be a keyword argument, i.e., you explicitly + need to write *filename="name_of_file"* because the arguments are passed + on to the :py:meth:`to_graph` function. Please check the documentation of that function + to learn which arguments are allowed. + + Parameters + ---------- + filename: str | Path + Name of the csv file which the data is exported to. + """ + classname = convert.to_snakecase(self.__class__.__name__).strip("_") + filename = filename if filename is not None else f"{classname}.csv" + if os.path.isabs(filename): + writeout_path = filename + else: + writeout_path = self._path / filename + df = self.to_frame(*args, **kwargs) + df.to_csv(writeout_path, index=False) + def plot(self, *args, **kwargs): """Wrapper around the :py:meth:`to_graph` function. @@ -49,7 +93,11 @@ def to_image(self, *args, filename=None, **kwargs): fig = self.to_plotly(*args, **kwargs) classname = convert.to_snakecase(self.__class__.__name__).strip("_") filename = filename if filename is not None else f"{classname}.png" - fig.write_image(self._path / filename) + if os.path.isabs(filename): + writeout_path = filename + else: + writeout_path = self._path / filename + fig.write_image(writeout_path) def _merge_graphs(graphs): diff --git a/src/py4vasp/_third_party/graph/series.py b/src/py4vasp/_third_party/graph/series.py index 946051f3..802af191 100644 --- a/src/py4vasp/_third_party/graph/series.py +++ b/src/py4vasp/_third_party/graph/series.py @@ -39,7 +39,9 @@ class Series: _frozen = False def __post_init__(self): - if len(self.x) != np.array(self.y).shape[-1]: + self.x = np.asarray(self.x) + self.y = np.asarray(self.y) + if len(self.x) != self.y.shape[-1]: message = "The length of the two plotted components is inconsistent." raise exception.IncorrectUsage(message) if self.width is not None and len(self.x) != self.width.shape[-1]: diff --git a/tests/data/conftest.py b/tests/data/conftest.py index 0265ed0c..c310e2e0 100644 --- a/tests/data/conftest.py +++ b/tests/data/conftest.py @@ -57,6 +57,8 @@ def should_test_method(name): return False if name == "to_image": # would have side effects return False + if name == "to_csv": + return False return True diff --git a/tests/third_party/graph/test_graph.py b/tests/third_party/graph/test_graph.py index 1663d3ce..26d05829 100644 --- a/tests/third_party/graph/test_graph.py +++ b/tests/third_party/graph/test_graph.py @@ -286,6 +286,73 @@ def test_add_label_to_multiple_lines(parabola, sine, Assert): assert graph.series[1].name == "new label sine" +def test_convert_parabola_to_frame(parabola, Assert, not_core): + graph = Graph(parabola) + df = graph.to_frame() + Assert.allclose(df["parabola.x"], parabola.x) + Assert.allclose(df["parabola.y"], parabola.y) + + +def test_convert_sequence_parabola_to_frame(parabola, sine, Assert, not_core): + sequence = [parabola, sine] + graph = Graph(sequence) + df = graph.to_frame() + Assert.allclose(df["parabola.x"], parabola.x) + Assert.allclose(df["parabola.y"], parabola.y) + Assert.allclose(df["sine.x"], sine.x) + Assert.allclose(df["sine.y"], sine.y) + + +def test_convert_multiple_lines(two_lines, Assert, not_core): + graph = Graph(two_lines) + df = graph.to_frame() + assert len(df.columns) == 3 + Assert.allclose(df["two_lines.x"], two_lines.x) + Assert.allclose(df["two_lines.y0"], two_lines.y[0]) + Assert.allclose(df["two_lines.y1"], two_lines.y[1]) + + +def test_convert_two_fatbands_to_frame(two_fatbands, Assert, not_core): + graph = Graph(two_fatbands) + df = graph.to_frame() + Assert.allclose(df["two_fatbands.x"], two_fatbands.x) + Assert.allclose(df["two_fatbands.y0"], two_fatbands.y[0]) + Assert.allclose(df["two_fatbands.y1"], two_fatbands.y[1]) + Assert.allclose(df["two_fatbands.width0"], two_fatbands.width[0]) + Assert.allclose(df["two_fatbands.width1"], two_fatbands.width[1]) + + +def test_write_csv(tmp_path, two_fatbands, non_numpy, Assert, not_core): + import pandas as pd + + sequence = [two_fatbands, *non_numpy] + graph = Graph(sequence) + graph.to_csv(tmp_path / "filename.csv") + ref = graph.to_frame() + actual = pd.read_csv(tmp_path / "filename.csv") + ref_rounded = np.round(ref.values, 12) + actual_rounded = np.round(actual.values, 12) + Assert.allclose(ref_rounded, actual_rounded) + + +def test_convert_different_length_series_to_frame( + parabola, two_lines, Assert, not_core +): + sequence = [two_lines, parabola] + graph = Graph(sequence) + df = graph.to_frame() + assert len(df) == max(len(parabola.x), len(two_lines.x)) + Assert.allclose(df["parabola.x"], parabola.x) + Assert.allclose(df["parabola.y"], parabola.y) + pad_width = len(parabola.x) - len(two_lines.x) + pad_nan = np.repeat(np.nan, pad_width) + padded_two_lines_x = np.hstack((two_lines.x, pad_nan)) + padded_two_lines_y = np.hstack((two_lines.y, np.vstack((pad_nan, pad_nan)))) + Assert.allclose(df["two_lines.x"], padded_two_lines_x) + Assert.allclose(df["two_lines.y0"], padded_two_lines_y[0]) + Assert.allclose(df["two_lines.y1"], padded_two_lines_y[1]) + + @patch("plotly.graph_objs.Figure._ipython_display_") def test_ipython_display(mock_display, parabola, not_core): graph = Graph(parabola) diff --git a/tests/third_party/graph/test_mixin.py b/tests/third_party/graph/test_mixin.py index bf61d21e..331fefe8 100644 --- a/tests/third_party/graph/test_mixin.py +++ b/tests/third_party/graph/test_mixin.py @@ -35,6 +35,22 @@ def test_converting_graph_to_plotly(): assert fig == GRAPH.to_plotly.return_value +def test_convert_graph_to_frame(): + example = ExampleGraph() + df = example.to_frame() + GRAPH.to_frame.assert_called_once_with() + assert df == GRAPH.to_frame.return_value + + +def test_convert_graph_to_csv(): + example = ExampleGraph() + example.to_csv() + GRAPH.to_frame.assert_called_once_with() + full_path = example._path / "example_graph.csv" + df = GRAPH.to_frame.return_value + df.to_csv.assert_called_once_with(full_path, index=False) + + def test_converting_graph_to_image(): example = ExampleGraph() example.to_image() @@ -42,6 +58,25 @@ def test_converting_graph_to_image(): fig.write_image.assert_called_once_with(example._path / "example_graph.png") +def test_converting_graph_to_csv_with_relative_filename(): + example = ExampleGraph() + example.to_csv(filename="example.csv") + full_path = example._path / "example.csv" + GRAPH.to_frame.assert_called_once_with() + df = GRAPH.to_frame.return_value + df.to_csv.assert_called_once_with(full_path, index=False) + + +def test_converting_graph_to_csv_with_absolute_filename(): + example = ExampleGraph() + basedir_path = example._path.absolute() + full_path = basedir_path / "example.csv" + example.to_csv(filename=full_path) + GRAPH.to_frame.assert_called_once_with() + df = GRAPH.to_frame.return_value + df.to_csv.assert_called_once_with(full_path, index=False) + + def test_converting_graph_to_image_with_filename(): example = ExampleGraph() example.to_image(filename="example.jpg") @@ -49,6 +84,15 @@ def test_converting_graph_to_image_with_filename(): fig.write_image.assert_called_once_with(example._path / "example.jpg") +def test_converting_graph_to_image_with_absolute_filename(): + example = ExampleGraph() + basedir_path = example._path.absolute() + full_path = basedir_path / "example.jpg" + example.to_image(filename=full_path) + fig = GRAPH.to_plotly.return_value + fig.write_image.assert_called_once_with(full_path) + + def test_filename_is_keyword_only_argument(): example = ExampleGraph() with pytest.raises(TypeError):