Skip to content

Commit

Permalink
Update SAM automatic mask generator colab
Browse files Browse the repository at this point in the history
  • Loading branch information
healthonrails committed Jun 16, 2023
1 parent aa4c6e5 commit 24afe8e
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions docs/tutorials/automatic_mask_generator_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,6 @@
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"! wget https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.mp4"
],
"metadata": {
"id": "Q_2o3bL_W26-"
},
"id": "Q_2o3bL_W26-",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "fd2bc687",
Expand Down Expand Up @@ -824,34 +812,43 @@
{
"cell_type": "code",
"source": [
"from sklearn.metrics.pairwise import cosine_similarity\n",
"from scipy.spatial.distance import euclidean, cosine\n",
"\n",
"def generate_mask_id(mask_features, existing_masks, threshold=0.9):\n",
"def generate_mask_id(mask_features, existing_masks, threshold=6.0, distance_metric=\"euclidean\"):\n",
" \"\"\"\n",
" Generates an ID for the mask based on its features and compares it with existing masks.\n",
"\n",
" Args:\n",
" mask_features (ndarray): The features of the mask.\n",
" existing_masks (list): List of existing masks and their features.\n",
" threshold (float): Similarity threshold for considering a match (default: 0.9).\n",
" distance_metric (str): Distance metric to be used (default: \"euclidean\").\n",
" Options: \"euclidean\", \"cosine\".\n",
"\n",
" Returns:\n",
" mask_id (int): The generated ID for the mask.\n",
" \"\"\"\n",
" mask_id = -1 # Initialize the mask ID\n",
"\n",
" if distance_metric == \"euclidean\":\n",
" distance_function = euclidean\n",
" elif distance_metric == \"cosine\":\n",
" distance_function = cosine\n",
" else:\n",
" raise ValueError(\"Invalid distance metric. Choose either 'euclidean' or 'cosine'.\")\n",
"\n",
" for idx, (existing_id, existing_features) in enumerate(existing_masks):\n",
" similarity = cosine_similarity(mask_features, existing_features)\n",
" similarity = distance_function(mask_features.flatten(), existing_features.flatten())\n",
"\n",
" if similarity > threshold:\n",
" if similarity < threshold:\n",
" mask_id = existing_id\n",
" break\n",
"\n",
" if mask_id == -1:\n",
" mask_id = len(existing_masks) + 1 # Assign a new ID if no match is found\n",
" existing_masks.append((mask_id, mask_features))\n",
" existing_masks.append((mask_id, mask_features.flatten()))\n",
"\n",
" return mask_id"
" return mask_id\n"
],
"metadata": {
"id": "pVufxTIGwHrg"
Expand All @@ -863,7 +860,6 @@
{
"cell_type": "code",
"source": [
"from math import inf\n",
"import pycocotools.mask as mask_util\n",
"def convert_to_annolid_format(frame_number, \n",
" masks,\n",
Expand Down

0 comments on commit 24afe8e

Please sign in to comment.