Skip to content

Commit

Permalink
fix: correct threshold for scoring vss
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Sep 19, 2024
1 parent bb1153e commit a7cb0db
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions aipipeline/prediction/vss_predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def process_image_batch(batch, config_dict):
top_k = 3
logger.info(f"Processing {len(batch)} images")
project = config_dict["tator"]["project"]
vss_threshold = 1. - float(config_dict["vss"]["threshold"]) # vss returns 1 - score, so we need to invert it
vss_threshold = float(config_dict["vss"]["threshold"])
url_vs = f"{config_dict['vss']['url']}/{top_k}/{project}"
url_load = config_dict["tator"]["url_load"]
logger.debug(f"URL: {url_vs} threshold: {vss_threshold}")
Expand All @@ -65,8 +65,8 @@ def process_image_batch(batch, config_dict):

predictions = response.json()["predictions"]
scores = response.json()["scores"]
# Scores are 1 - score so we need to invert them
scores = [[1 - x for x in y] for y in scores]
# Scores are 1 - score, so we need to invert them
scores = [[1 - float(x) for x in y] for y in scores]
logger.debug(f"Predictions: {predictions}")
logger.debug(f"Scores: {scores}")

Expand All @@ -81,11 +81,11 @@ def process_image_batch(batch, config_dict):
predictions = [predictions[i:i + top_k] for i in range(0, batch_size * top_k, top_k)]
####
# Skip if rhizaria, larvacean, copepod, fecal_pellet, centric_diatom, football, or larvacean are in the predictions
low_confidence_labels = ["rhizaria", "copepod", "fecal_pellet", "centric_diatom", "football", "larvacean"]
low_confidence_labels = ["copepod"]
if not any([x in low_confidence_labels for x in predictions]):
logger.info(f"=======>Did not find {low_confidence_labels}")
return 0
# low_confidence_labels = ["rhizaria", "copepod", "fecal_pellet", "centric_diatom", "football", "larvacean"]
# low_confidence_labels = ["copepod"]
# if not any([x in low_confidence_labels for x in predictions]):
# logger.info(f"=======>Did not find {low_confidence_labels}")
# return 0

file_paths = [x[1][0] for x in files]
for i, element in enumerate(zip(scores, predictions)):
Expand All @@ -107,7 +107,7 @@ def process_image_batch(batch, config_dict):
"accept": "application/json",
"Content-Type": "application/json",
}
data = {"loc_id": database_id, "project_name": project, "dry_run": False, "score": 1 - best_score}
data = {"loc_id": database_id, "project_name": project, "dry_run": False, "score": best_score}
logger.debug(f"{url_load}/{best_pred}")
response = requests.post(f"{url_load}/{best_pred}", headers=headers, json=data)
if response.status_code == 200:
Expand Down Expand Up @@ -143,18 +143,28 @@ def run_pipeline(argv=None):
conf_files, config_dict = setup_config(args.config, silent=True)

with beam.Pipeline(options=options) as p:
image_pcoll = (
p
| "MatchFiles" >> MatchFiles(file_pattern=f"{args.image_dir}*")
)

# Apply the limit conditionally
if args.max_images:
image_pcoll = (
image_pcoll
| 'Limit Matches' >> beam.combiners.Sample.FixedSizeGlobally(int(args.max_images))
| "FlattenMatches" >> beam.FlatMap(lambda x: x)
)

(
p
| "MatchFiles" >> MatchFiles(file_pattern=f"{args.image_dir}*")
| 'Limit Matches' >> beam.combiners.Sample.FixedSizeGlobally(int(args.max_images))
| "FlattenMatches" >> beam.FlatMap(lambda x: x)
| "ReadFiles" >> ReadMatches()
| "ReadImages" >> beam.Map(read_image)
| "BatchImages" >> beam.BatchElements(min_batch_size=3, max_batch_size=3)
| "ProcessBatches" >> beam.Map(process_image_batch, config_dict)
| "SumResults" >> beam.CombineGlobally(sum)
| "WriteResults" >> beam.io.WriteToText("num_loaded")
| "LogResults" >> beam.Map(logger.info)
image_pcoll
| "ReadFiles" >> ReadMatches()
| "ReadImages" >> beam.Map(read_image)
| "BatchImages" >> beam.BatchElements(min_batch_size=3, max_batch_size=3)
| "ProcessBatches" >> beam.Map(process_image_batch, config_dict)
| "SumResults" >> beam.CombineGlobally(sum)
| "WriteResults" >> beam.io.WriteToText("num_loaded")
| "LogResults" >> beam.Map(logger.info)
)


Expand Down

0 comments on commit a7cb0db

Please sign in to comment.