Skip to content

Commit

Permalink
Segmonitor build
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Dec 14, 2023
1 parent 253e01b commit 4f1ee97
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 135 deletions.
8 changes: 8 additions & 0 deletions config/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"model_type": "MTLSD",
"iterations": 100000,
"warmup": 100000,
"raw_file": "path/to/zarr/or/n5",
"voxel_size": 33,
"python_script_path": "path/to/python_script.py"
}
2 changes: 0 additions & 2 deletions optoseg/README.md

This file was deleted.

3 changes: 0 additions & 3 deletions optoseg/go.mod

This file was deleted.

39 changes: 0 additions & 39 deletions optoseg/src/jobs/job.go

This file was deleted.

17 changes: 0 additions & 17 deletions optoseg/src/main.go

This file was deleted.

11 changes: 11 additions & 0 deletions segmonitor/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "segmonitor"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
mongodb = "2.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
55 changes: 55 additions & 0 deletions segmonitor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use std::process::Command;
use serde::{Deserialize, Serialize};
use mongodb::{Client, options::ClientOptions};

#[derive(Debug, Deserialize, Serialize)]
struct Config {
model_type: String,
iterations: i32,
warmup: i32,
raw_file: String,
voxel_size: i32,
python_script_path: String,
}

pub mod segmonitor {
pub fn train_model_from_config(config_path: &str) {
let config = load_config(config_path);

println!("Training model: {}", config.model_type);
println!("Iterations: {}", config.iterations);
println!("Warmup: {}", config.warmup);
println!("Raw file: {}", config.raw_file);
println!("Voxel size: {}", config.voxel_size);

call_python_train(&config.python_script_path);

save_to_mongodb(&config);
}

fn load_config(config_path: &str) -> Config {
// Load configuration from JSON file
let config_str = std::fs::read_to_string(config_path).expect("Error reading config file");
serde_json::from_str(&config_str).expect("Error parsing JSON")
}

fn call_python_train(script_path: &str) {
let output = Command::new("python")
.arg(script_path)
.output()
.expect("Failed to execute training");

if output.status.success() {
println!("Traning executed successfully!");
} else {
println!("Error executing training:\n{}", String::from_utf8_lossy(&output.stderr));
}
}

fn save_to_mongodb(config: &Config) {
println!("Saving metrics to MongoDB...");
let client_options = ClientOptions::parse("mongodb://localhost:27017").unwrap();
let client = Client::with_options(client_options).unwrap();
// TODO: dd MongoDB insertion logic here
}
}
6 changes: 6 additions & 0 deletions segmonitor/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use segmonitor;

fn main() {
let config_path = "path/to/config.json";
segmonitor::train_model_from_config(config_path);
}
25 changes: 13 additions & 12 deletions src/autoseg/eval/eval_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sqlite3
import json


class Database:
"""
Simple SQLite Database Wrapper for Storing and Retrieving Scores.
Expand All @@ -9,17 +10,17 @@ class Database:
Each score entry is associated with a network, checkpoint, threshold, and a dictionary of scores.
Args:
db_name (str):
db_name (str):
The name of the SQLite database file.
table_name (str):
table_name (str):
The name of the table within the database (default is 'scores_table').
Attributes:
conn (sqlite3.Connection):
conn (sqlite3.Connection):
The SQLite database connection.
cursor (sqlite3.Cursor):
cursor (sqlite3.Cursor):
The SQLite database cursor.
table_name (str):
table_name (str):
The name of the table within the database.
Methods:
Expand Down Expand Up @@ -50,13 +51,13 @@ def add_score(self, network, checkpoint, threshold, scores_dict):
Add a score entry to the database.
Args:
network (str):
network (str):
The name of the network.
checkpoint (int):
checkpoint (int):
The checkpoint number.
threshold (float):
threshold (float):
The threshold value.
scores_dict (dict):
scores_dict (dict):
A dictionary containing scores.
"""
assert type(network) is str
Expand All @@ -74,11 +75,11 @@ def get_scores(self, networks=None, checkpoints=None, thresholds=None):
Retrieve scores from the database based on specified conditions.
Args:
networks (str, list):
networks (str, list):
The name or list of names of networks to filter on.
checkpoints (int, list):
checkpoints (int, list):
The checkpoint number or list of checkpoint numbers to filter on.
thresholds (float, list):
thresholds (float, list):
The threshold value or list of threshold values to filter on.
Returns:
Expand Down
30 changes: 15 additions & 15 deletions src/autoseg/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ def segment_and_validate(
It logs information about the segmentation and validation process.
Args:
model_checkpoint (str):
model_checkpoint (str):
The checkpoint of the segmentation model to use (default is "latest").
checkpoint_num (int):
checkpoint_num (int):
The checkpoint number for the affinity model (default is 250000).
setup_num (str):
setup_num (str):
The setup number for the affinity model (default is "1738").
Returns:
dict:
dict:
A dictionary containing scores and evaluation metrics.
"""
logger.info(
Expand Down Expand Up @@ -85,27 +85,27 @@ def validate(
Validate segmentation results using specified parameters.
Args:
checkpoint (str):
checkpoint (str):
The checkpoint identifier.
threshold (float):
threshold (float):
The threshold value.
offset (str):
offset (str):
The offset for ROI (default is "3960,3960,3960").
roi_shape (str):
roi_shape (str):
The shape of ROI (default is "31680,31680,31680").
skel (str):
skel (str):
The path to the skeleton data file (default is "../../data/XPRESS_validation_skels.npz").
zarr (str):
zarr (str):
The path to the Zarr file for storing segmentation data (default is "./validation.zarr").
h5 (str):
h5 (str):
The path to the HDF5 file for storing validation data (default is "validation.h5").
ds (str):
ds (str):
The dataset name (default is "pred_seg").
print_errors (bool):
print_errors (bool):
Print errors during validation (default is False).
print_in_xyz (bool):
print_in_xyz (bool):
Print coordinates in XYZ format (default is False).
downsample (int):
downsample (int):
Downsample factor for evaluation (default is None).
"""
network = os.path.abspath(".").split(os.path.sep)[-1]
Expand Down
Loading

0 comments on commit 4f1ee97

Please sign in to comment.