Skip to content

Commit

Permalink
Refactored association sorting to common util
Browse files Browse the repository at this point in the history
  • Loading branch information
VukW committed Aug 10, 2024
1 parent 039f496 commit c225a5e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 53 deletions.
20 changes: 3 additions & 17 deletions cli/medperf/web_ui/benchmarks/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from medperf.entities.cube import Cube
from medperf.account_management import get_medperf_user_data
from medperf.enums import Status
from medperf.web_ui.common import templates
from medperf.web_ui.common import templates, sort_associations_display

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -37,23 +37,9 @@ def benchmark_detail_ui(request: Request, benchmark_id: int):
datasets_associations = Benchmark.get_datasets_associations(benchmark_uid=benchmark_id)
models_associations = Benchmark.get_models_associations(benchmark_uid=benchmark_id)

approval_status_order = {
Status.PENDING: 0,
Status.APPROVED: 1,
Status.REJECTED: 2,
}
datasets_associations = sort_associations_display(datasets_associations)
models_associations = sort_associations_display(models_associations)

def assoc_sorting_key(assoc):
# lower status - first
status_order = approval_status_order.get(assoc.approval_status, -1)
# recent associations - first
date_order = -(assoc.approved_at or assoc.created_at).timestamp()
return status_order, date_order

datasets_associations = sorted(datasets_associations, key=assoc_sorting_key)
models_associations = sorted(models_associations, key=assoc_sorting_key)

# Fetch datasets and models information
datasets = {assoc.dataset: Dataset.get(assoc.dataset) for assoc in datasets_associations if assoc.dataset}
models = {assoc.model_mlcube: Cube.get(assoc.model_mlcube) for assoc in models_associations if assoc.model_mlcube}

Expand Down
31 changes: 30 additions & 1 deletion cli/medperf/web_ui/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

from fastapi.requests import Request

from medperf.entities.association import Association
from medperf.enums import Status

templates = Jinja2Templates(directory=str(resources.path("medperf.web_ui", "templates")))

logger = logging.getLogger(__name__)
Expand All @@ -18,4 +21,30 @@ async def custom_exception_handler(request: Request, exc: Exception):
context = {"request": request, "exception": exc}

# Return a detailed error page
return templates.TemplateResponse("error.html", context, status_code=500)
return templates.TemplateResponse("error.html", context, status_code=500)


def sort_associations_display(associations: list[Association]) -> list[Association]:
"""
Sorts associations:
- by approval status (pending, approved, rejected)
- by date (recent first)
Args:
associations: associations to sort
Returns: sorted list
"""

approval_status_order = {
Status.PENDING: 0,
Status.APPROVED: 1,
Status.REJECTED: 2,
}

def assoc_sorting_key(assoc):
# lower status - first
status_order = approval_status_order.get(assoc.approval_status, -1)
# recent associations - first
date_order = -(assoc.approved_at or assoc.created_at).timestamp()
return status_order, date_order

return sorted(associations, key=assoc_sorting_key)
20 changes: 2 additions & 18 deletions cli/medperf/web_ui/datasets/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from medperf.entities.dataset import Dataset
from medperf.entities.benchmark import Benchmark
from medperf.enums import Status
from medperf.web_ui.common import templates
from medperf.web_ui.common import templates, sort_associations_display

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,25 +43,9 @@ def dataset_detail_ui(request: Request, dataset_id: int):
prep_cube = Cube.get(cube_uid=dataset.data_preparation_mlcube)
prep_cube_name = prep_cube.name if prep_cube else "Unknown"

# Fetching associations and related benchmarks
benchmark_associations = Dataset.get_benchmarks_associations(dataset_uid=dataset_id)
benchmark_associations = sort_associations_display(benchmark_associations)

approval_status_order = {
Status.PENDING: 0,
Status.APPROVED: 1,
Status.REJECTED: 2,
}

def assoc_sorting_key(assoc):
# lower status - first
status_order = approval_status_order.get(assoc.approval_status, -1)
# recent associations - first
date_order = -(assoc.approved_at or assoc.created_at).timestamp()
return status_order, date_order

benchmark_associations = sorted(benchmark_associations, key=assoc_sorting_key)

# Fetch benchmarks information
benchmarks = {assoc.benchmark: Benchmark.get(assoc.benchmark) for assoc in benchmark_associations if
assoc.benchmark}

Expand Down
19 changes: 2 additions & 17 deletions cli/medperf/web_ui/mlcubes/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from medperf.entities.cube import Cube
from medperf.entities.benchmark import Benchmark
from medperf.enums import Status
from medperf.web_ui.common import templates
from medperf.web_ui.common import templates, sort_associations_display

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,23 +38,8 @@ def mlcubes_ui(request: Request, local_only: bool = False, mine_only: bool = Fal
def mlcube_detail_ui(request: Request, mlcube_id: int):
mlcube = Cube.get(cube_uid=mlcube_id, valid_only=False)

# Fetching associations and related benchmarks
benchmarks_associations = Cube.get_benchmarks_associations(mlcube_uid=mlcube_id)

approval_status_order = {
Status.PENDING: 0,
Status.APPROVED: 1,
Status.REJECTED: 2,
}

def assoc_sorting_key(assoc):
# lower status - first
status_order = approval_status_order.get(assoc.approval_status, -1)
# recent associations - first
date_order = -(assoc.approved_at or assoc.created_at).timestamp()
return status_order, date_order

benchmarks_associations = sorted(benchmarks_associations, key=assoc_sorting_key)
benchmarks_associations = sort_associations_display(benchmarks_associations)

benchmarks = {assoc.benchmark: Benchmark.get(assoc.benchmark) for assoc in benchmarks_associations if
assoc.benchmark}
Expand Down

0 comments on commit c225a5e

Please sign in to comment.