Skip to content

Commit

Permalink
added documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Geeks-Sid committed Jul 13, 2023
1 parent 9a21855 commit 936fc9c
Showing 1 changed file with 31 additions and 43 deletions.
74 changes: 31 additions & 43 deletions GANDLF/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,36 @@
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from io import StringIO


def plot_all(
df_training,
df_validation,
df_testing,
output_plot_dir,
):
# Drop any columns that might have "_" in the values of their rows, this can be checked through the first row or data type
def plot_all(df_training, df_validation, df_testing, output_plot_dir):
"""
Plots training, validation, and testing data for loss and other metrics.
Args:
df_training (pd.DataFrame): DataFrame containing training data.
df_validation (pd.DataFrame): DataFrame containing validation data.
df_testing (pd.DataFrame): DataFrame containing testing data.
output_plot_dir (str): Directory to save the plots.
Returns:
tuple: Tuple containing the modified training, validation, and testing DataFrames.
"""
# Identify columns whose row values contain "_" so they can be excluded from the metrics below
banned_cols = [

Check warning on line 22 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L22

Added line #L22 was not covered by tests
col
for col in df_training.columns
if any("_" in str(val) for val in df_training[col].values)
]

# Determine metrics from the column names by removing the train_ from the column names
# Determine metrics from the column names by removing the "train_" prefix
metrics = [

Check warning on line 29 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L29

Added line #L29 was not covered by tests
col.replace("train_", "")
for col in df_training.columns
if "train_" in col and col not in banned_cols
]

print("Metrics found: ", metrics)
print("Banned columns: ", banned_cols)

# Split the values of the banned columns into multiple columns # Code for splitting output
# Split the values of the banned columns into multiple columns
# for df in [df_training, df_validation, df_testing]:
# for col in banned_cols:
# if df[col].dtype == "object":
Expand All @@ -47,10 +50,7 @@ def plot_all(
any(metric in col for col in df_training.columns) for metric in metrics
), "None of the specified metrics is in the dataframe."

required_cols = [
"epoch_no",
"train_loss",
]
required_cols = ["epoch_no", "train_loss"]

Check warning on line 53 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L53

Added line #L53 was not covered by tests

# Check if the required columns are in the dataframe
assert all(

Check warning on line 56 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L56

Added line #L56 was not covered by tests
Expand All @@ -62,29 +62,15 @@ def plot_all(
# Plot for loss
plt.figure(figsize=(12, 6))
if "train_loss" in df_training.columns:
sns.lineplot(
data=df_training,
x="epoch_no",
y="train_loss",
label="Training",
)
sns.lineplot(data=df_training, x="epoch_no", y="train_loss", label="Training")

Check warning on line 65 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L63-L65

Added lines #L63 - L65 were not covered by tests

if "valid_loss" in df_validation.columns:
sns.lineplot(

Check warning on line 68 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L67-L68

Added lines #L67 - L68 were not covered by tests
data=df_validation,
x="epoch_no",
y="valid_loss",
label="Validation",
data=df_validation, x="epoch_no", y="valid_loss", label="Validation"
)

if df_testing is not None:
if "test_loss" in df_testing.columns:
sns.lineplot(
data=df_testing,
x="epoch_no",
y="test_loss",
label="Testing",
)
if df_testing is not None and "test_loss" in df_testing.columns:
sns.lineplot(data=df_testing, x="epoch_no", y="test_loss", label="Testing")

Check warning on line 73 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L72-L73

Added lines #L72 - L73 were not covered by tests

plt.xlim(0, epochs - 1)
plt.xlabel("Epoch")
Expand Down Expand Up @@ -114,14 +100,16 @@ def plot_all(
y=metric_col.replace("train", "valid"),
label=f"Validation {metric_col}",
)
if df_testing is not None:
if metric_col.replace("train", "test") in df_testing.columns:
sns.lineplot(
data=df_testing,
x="epoch_no",
y=metric_col.replace("train", "test"),
label=f"Testing {metric_col}",
)
if (

Check warning on line 103 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L103

Added line #L103 was not covered by tests
df_testing is not None
and metric_col.replace("train", "test") in df_testing.columns
):
sns.lineplot(

Check warning on line 107 in GANDLF/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/utils/plot_utils.py#L107

Added line #L107 was not covered by tests
data=df_testing,
x="epoch_no",
y=metric_col.replace("train", "test"),
label=f"Testing {metric_col}",
)
plt.xlim(0, epochs - 1)
plt.xlabel("Epoch")
plt.ylabel(metric.capitalize())
Expand Down

0 comments on commit 936fc9c

Please sign in to comment.