Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of metadata suggestion endpoint #1403

Merged
merged 3 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions nmdc_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import json
import logging
from io import BytesIO, StringIO
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

import requests
from fastapi import APIRouter, Depends, Header, HTTPException, Response, status
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, status
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
Expand All @@ -18,6 +18,7 @@
from nmdc_server.data_object_filters import WorkflowActivityTypeEnum
from nmdc_server.database import get_db
from nmdc_server.ingest.envo import nested_envo_trees
from nmdc_server.metadata import SampleMetadataSuggester
from nmdc_server.models import (
IngestLock,
SubmissionEditorRole,
Expand Down Expand Up @@ -1040,6 +1041,36 @@ async def submit_metadata(
return submission


@router.post(
"/metadata_submission/suggest",
tags=["metadata_submission"],
responses=login_required_responses,
)
async def suggest_metadata(
body: List[schemas_submission.MetadataSuggestionRequest],
suggester: SampleMetadataSuggester = Depends(SampleMetadataSuggester),
types: Union[List[schemas_submission.MetadataSuggestionType], None] = Query(None),
user: models.User = Depends(get_current_user),
) -> List[schemas_submission.MetadataSuggestion]:
response: List[schemas_submission.MetadataSuggestion] = []
for item in body:
suggestions = suggester.get_suggestions(item.data, types=types)
for slot, value in suggestions.items():
response.append(
schemas_submission.MetadataSuggestion(
type=(
schemas_submission.MetadataSuggestionType.REPLACE
if slot in item.data
else schemas_submission.MetadataSuggestionType.ADD
),
row=item.row,
slot=slot,
value=value,
)
)
return response


@router.get(
"/users", responses=login_required_responses, response_model=query.UserResponse, tags=["user"]
)
Expand Down
71 changes: 71 additions & 0 deletions nmdc_server/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import re
from typing import Any, Callable, Dict, List, Optional

from nmdc_geoloc_tools import GeoEngine

from nmdc_server.schemas_submission import MetadataSuggestionType


class SampleMetadataSuggester:
"""A class to suggest sample metadata values based on partial sample metadata."""

def __init__(self):
self._geo_engine: Optional[GeoEngine] = None

@property
def geo_engine(self) -> GeoEngine:
"""A GeoEngine instance for looking up geospatial data."""
if self._geo_engine is None:
self._geo_engine = GeoEngine()
return self._geo_engine

def suggest_elevation_from_lat_lon(self, sample: Dict[str, str]) -> Optional[float]:
"""Suggest an elevation for a sample based on its lat_lon."""
lat_lon = sample.get("lat_lon", None)
if lat_lon is None:
return None
lat_lon_split = re.split("[, ]+", lat_lon)
if len(lat_lon_split) == 2:
try:
lat, lon = map(float, lat_lon_split)
return self.geo_engine.get_elevation((lat, lon))
except ValueError:
# This could happen if the lat_lon string is not parseable as a float
# or the GeoEngine determined they are invalid values. In either case,
# just don't suggest an elevation.
pass
return None

def get_suggestions(
self, sample: Dict[str, str], *, types: Optional[List[MetadataSuggestionType]] = None
) -> Dict[str, str]:
"""Suggest metadata values for a sample.

Returns a dictionary where the keys are sample metadata slots and the values are suggested
values.
"""

# Not explicitly supplying types implies using all types.
if types is None:
types = list(MetadataSuggestionType)

do_add = MetadataSuggestionType.ADD in types
do_replace = MetadataSuggestionType.REPLACE in types

# Map from sample metadata slot to a list of functions that can suggest values for
# that slot.
suggesters: dict[str, list[Callable[[dict[str, str]], Optional[Any]]]] = {
"elev": [self.suggest_elevation_from_lat_lon],
}

suggestions = {}

for target_slot, suggester_list in suggesters.items():
has_data = target_slot in sample and sample[target_slot]
if (do_add and not has_data) or (do_replace and has_data):
for suggester_fn in suggester_list:
suggestion = suggester_fn(sample)
if suggestion is not None:
suggestions[target_slot] = str(suggestion)

return suggestions
18 changes: 18 additions & 0 deletions nmdc_server/schemas_submission.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID

Expand Down Expand Up @@ -144,3 +145,20 @@ def populate_roles(cls, metadata_submission, values):


SubmissionMetadataSchema.update_forward_refs()


class MetadataSuggestionRequest(BaseModel):
row: int
data: Dict[str, str]


class MetadataSuggestionType(str, Enum):
ADD = "add"
REPLACE = "replace"


class MetadataSuggestion(BaseModel):
type: MetadataSuggestionType
row: int
slot: str
value: str
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"ipython==8.10.0",
"itsdangerous==2.0.1",
"mypy<0.920",
"nmdc-geoloc-tools==0.1.1",
"nmdc-schema==10.8.0",
"nmdc-submission-schema==10.8.0",
"pint==0.18",
Expand Down
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from factory import random
from nmdc_geoloc_tools import GeoEngine
from starlette.testclient import TestClient

from nmdc_server import database, schemas
Expand All @@ -17,6 +18,27 @@ def set_seed(connection):
random.reseed_random("nmdc")


@pytest.fixture(autouse=True)
def patch_geo_engine(monkeypatch):
"""Patch all the GeoEngine methods that make external network requests."""

def mock_get_elevation(self, lat_lon):
lat, lon = lat_lon
if not -90 <= lat <= 90:
raise ValueError(f"Invalid Latitude: {lat}")
if not -180 <= lon <= 180:
raise ValueError(f"Invalid Longitude: {lon}")
return 16.0

def mock_not_implemented(self, *args, **kwargs):
raise NotImplementedError()

monkeypatch.setattr(GeoEngine, "get_elevation", mock_get_elevation)
monkeypatch.setattr(GeoEngine, "get_fao_soil_type", mock_not_implemented)
monkeypatch.setattr(GeoEngine, "get_landuse", mock_not_implemented)
monkeypatch.setattr(GeoEngine, "get_landuse_dates", mock_not_implemented)


@pytest.fixture(scope="session")
def connection():
assert settings.environment == "testing"
Expand Down
33 changes: 33 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from nmdc_server.metadata import SampleMetadataSuggester


def test_sample_metadata_suggester_elevation():
suggester = SampleMetadataSuggester()

# Test with valid lat_lon
sample = {"lat_lon": "37.875766 -122.248580"}
elevation = suggester.suggest_elevation_from_lat_lon(sample)
assert elevation == 16.0

# Be tolerant of a comma separator
sample = {"lat_lon": "37.875766, -122.248580"}
elevation = suggester.suggest_elevation_from_lat_lon(sample)
assert elevation == 16.0

# Don't return a suggestion when lat_lon is missing
sample = {}
elevation = suggester.suggest_elevation_from_lat_lon(sample)
assert elevation is None

# Don't return a suggestion when lat_lon is invalid
sample = {"lat_lon": "91.0 -122.248580"}
elevation = suggester.suggest_elevation_from_lat_lon(sample)
assert elevation is None

sample = {"lat_lon": "no good"}
elevation = suggester.suggest_elevation_from_lat_lon(sample)
assert elevation is None

sample = {"lat_lon": "0 0 0"}
elevation = suggester.suggest_elevation_from_lat_lon(sample)
assert elevation is None
55 changes: 55 additions & 0 deletions tests/test_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
from nmdc_server.schemas_submission import SubmissionMetadataSchema, SubmissionMetadataSchemaPatch


@pytest.fixture
def suggest_payload():
return [
{"row": 1, "data": {"foo": "bar", "lat_lon": "44.058648, -123.095277"}},
{"row": 3, "data": {"elev": 0, "lat_lon": "44.046389 -123.051910"}},
{"row": 4, "data": {"foo": "bar"}},
{"row": 5, "data": {"lat_lon": "garbage foo bar"}},
]


def test_list_submissions(db: Session, client: TestClient, logged_in_user):
submission = fakes.MetadataSubmissionFactory(
author=logged_in_user, author_orcid=logged_in_user.orcid
Expand Down Expand Up @@ -610,3 +620,48 @@ def test_sync_submission_study_name(db: Session, client: TestClient, logged_in_u
response = client.request(method="GET", url=f"/api/metadata_submission/{submission.id}")
assert response.status_code == 200
assert response.json()["study_name"] == expected_val


def test_metadata_suggest(client: TestClient, suggest_payload, logged_in_user):
response = client.request(
method="POST", url="/api/metadata_submission/suggest", json=suggest_payload
)
assert response.status_code == 200
assert response.json() == [
{"type": "add", "row": 1, "slot": "elev", "value": "16.0"},
{"type": "replace", "row": 3, "slot": "elev", "value": "16.0"},
]


def test_metadata_suggest_single_type(client: TestClient, suggest_payload, logged_in_user):
response = client.request(
method="POST",
url="/api/metadata_submission/suggest?types=add",
json=suggest_payload,
)
assert response.status_code == 200
assert response.json() == [
{"type": "add", "row": 1, "slot": "elev", "value": "16.0"},
]


def test_metadata_suggest_multiple_types(client: TestClient, suggest_payload, logged_in_user):
response = client.request(
method="POST",
url="/api/metadata_submission/suggest?types=add&types=replace",
json=suggest_payload,
)
assert response.status_code == 200
assert response.json() == [
{"type": "add", "row": 1, "slot": "elev", "value": "16.0"},
{"type": "replace", "row": 3, "slot": "elev", "value": "16.0"},
]


def test_metadata_suggest_invalid_type(client: TestClient, suggest_payload, logged_in_user):
response = client.request(
method="POST",
url="/api/metadata_submission/suggest?types=whatever",
json=suggest_payload,
)
assert response.status_code == 422