Skip to content

Commit

Permalink
fix: removed any augmented data for init vss and other minor improvem…
Browse files Browse the repository at this point in the history
…ents to plot names
  • Loading branch information
danellecline committed Aug 30, 2024
1 parent 14bf018 commit 2f040d0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
5 changes: 3 additions & 2 deletions aipipeline/metrics/calc_accuracy_vss.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ def calc_accuracy(config: dict, image_dir: str, password: str):
plt.suptitle(f"CM {project} exemplars. Top-1 Accuracy: {accuracy_top1:.2f}, Top-3 Accuracy: {accuracy_top3:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}")
d = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
plt.title(d)
plt.savefig(f"confusion_matrix_{project}_{d}.png")
plot_name = f"confusion_matrix_{project}_{datetime.now():%Y-%m-%d %H%M%S}.png"
logger.info(f"Saving confusion matrix to {plot_name}")
plt.savefig(plot_name)
plt.close()
logger.info(f"Confusion matrix saved to confusion_matrix_{project}_{d}.png")


def main(argv=None):
Expand Down
13 changes: 8 additions & 5 deletions aipipeline/metrics/plot_tsne_vss.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def plot_tsne(config: dict, password: str):
vectors_2d = tsne.fit_transform(v)
logging.info(f"t-SNE completed on {len(v)} vectors")

# Get the width of the vectors
vector_width = len(v[0]) - 1

# Plot the t-SNE results, colored by class
plt.figure(figsize=(12, 12))

Expand Down Expand Up @@ -135,18 +138,18 @@ def plot_tsne(config: dict, password: str):

# Customize and place the legend outside the plot
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), title="Classes")
plt.suptitle(f"t-SNE of 768-dimensional vectors {project} exemplars")
d = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
plt.title(d)
plt.suptitle(f"t-SNE of {vector_width}-dimensional vectors {project} exemplars")
plt.title(f"{datetime.now():%Y-%m-%d %H:%M:%S}")

# Plot the t-SNE results, colored by class
for i, class_name in enumerate(class_names):
x, y = vectors_2d[i]
idx = np.where(np.unique(class_names) == class_name)[0][0]
plt.scatter(x, y, c=colors[idx], label=class_name, s=10, alpha=0.5)

logging.info(f"Saving plot to tsne_plot_{project}_{d}.png")
plt.savefig(f"tsne_plot_{project}_{d}.png")
plot_name = f"tsne_plot_{project}_{datetime.now():%Y-%m-%d %H%M%S}.png"
logging.info(f"Saving plot to {plot_name}")
plt.savefig(plot_name)
plt.show()
except Exception as e:
logging.exception(f"Error: {e}")
Expand Down
8 changes: 8 additions & 0 deletions aipipeline/prediction/vss_init_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime

import dotenv
import glob
import os
from pathlib import Path
import apache_beam as beam
Expand Down Expand Up @@ -138,6 +139,13 @@ def run_pipeline(argv=None):
if not args.skip_clean:
clean(base_path)

# Always remove any previous augmented data before starting
logger.info("Removing any previous augmented data")
pattern = os.path.join(processed_data, '*.*.png')
files = glob.glob(pattern)
for file in files:
os.remove(file)

with beam.Pipeline(options=options) as p:
(
p
Expand Down

0 comments on commit 2f040d0

Please sign in to comment.