diff --git a/cli/medperf/web_ui/benchmarks/routes.py b/cli/medperf/web_ui/benchmarks/routes.py index 2ac6c2d8f..d31a1e60e 100644 --- a/cli/medperf/web_ui/benchmarks/routes.py +++ b/cli/medperf/web_ui/benchmarks/routes.py @@ -9,7 +9,7 @@ from medperf.entities.cube import Cube from medperf.account_management import get_medperf_user_data from medperf.enums import Status -from medperf.web_ui.common import templates +from medperf.web_ui.common import templates, sort_associations_display router = APIRouter() logger = logging.getLogger(__name__) @@ -37,23 +37,9 @@ def benchmark_detail_ui(request: Request, benchmark_id: int): datasets_associations = Benchmark.get_datasets_associations(benchmark_uid=benchmark_id) models_associations = Benchmark.get_models_associations(benchmark_uid=benchmark_id) - approval_status_order = { - Status.PENDING: 0, - Status.APPROVED: 1, - Status.REJECTED: 2, - } + datasets_associations = sort_associations_display(datasets_associations) + models_associations = sort_associations_display(models_associations) - def assoc_sorting_key(assoc): - # lower status - first - status_order = approval_status_order.get(assoc.approval_status, -1) - # recent associations - first - date_order = -(assoc.approved_at or assoc.created_at).timestamp() - return status_order, date_order - - datasets_associations = sorted(datasets_associations, key=assoc_sorting_key) - models_associations = sorted(models_associations, key=assoc_sorting_key) - - # Fetch datasets and models information datasets = {assoc.dataset: Dataset.get(assoc.dataset) for assoc in datasets_associations if assoc.dataset} models = {assoc.model_mlcube: Cube.get(assoc.model_mlcube) for assoc in models_associations if assoc.model_mlcube} diff --git a/cli/medperf/web_ui/common.py b/cli/medperf/web_ui/common.py index 9583d0ed8..b5a34d9b6 100644 --- a/cli/medperf/web_ui/common.py +++ b/cli/medperf/web_ui/common.py @@ -5,6 +5,9 @@ from fastapi.requests import Request +from medperf.entities.association import Association +from medperf.enums import Status + templates = Jinja2Templates(directory=str(resources.path("medperf.web_ui", "templates"))) logger = logging.getLogger(__name__) @@ -18,4 +21,30 @@ async def custom_exception_handler(request: Request, exc: Exception): context = {"request": request, "exception": exc} # Return a detailed error page - return templates.TemplateResponse("error.html", context, status_code=500) \ No newline at end of file + return templates.TemplateResponse("error.html", context, status_code=500) + + +def sort_associations_display(associations: list[Association]) -> list[Association]: + """ + Sorts associations: + - by approval status (pending, approved, rejected) + - by date (recent first) + Args: + associations: associations to sort + Returns: sorted list + """ + + approval_status_order = { + Status.PENDING: 0, + Status.APPROVED: 1, + Status.REJECTED: 2, + } + + def assoc_sorting_key(assoc): + # lower status - first + status_order = approval_status_order.get(assoc.approval_status, -1) + # recent associations - first + date_order = -(assoc.approved_at or assoc.created_at).timestamp() + return status_order, date_order + + return sorted(associations, key=assoc_sorting_key) diff --git a/cli/medperf/web_ui/datasets/routes.py b/cli/medperf/web_ui/datasets/routes.py index 5d47814de..c4931ec9e 100644 --- a/cli/medperf/web_ui/datasets/routes.py +++ b/cli/medperf/web_ui/datasets/routes.py @@ -11,7 +11,7 @@ from medperf.entities.dataset import Dataset from medperf.entities.benchmark import Benchmark from medperf.enums import Status -from medperf.web_ui.common import templates +from medperf.web_ui.common import templates, sort_associations_display router = APIRouter() logger = logging.getLogger(__name__) @@ -43,25 +43,9 @@ def dataset_detail_ui(request: Request, dataset_id: int): prep_cube = Cube.get(cube_uid=dataset.data_preparation_mlcube) prep_cube_name = prep_cube.name if prep_cube else "Unknown" - # Fetching associations and related benchmarks benchmark_associations = Dataset.get_benchmarks_associations(dataset_uid=dataset_id) + benchmark_associations = sort_associations_display(benchmark_associations) - approval_status_order = { - Status.PENDING: 0, - Status.APPROVED: 1, - Status.REJECTED: 2, - } - - def assoc_sorting_key(assoc): - # lower status - first - status_order = approval_status_order.get(assoc.approval_status, -1) - # recent associations - first - date_order = -(assoc.approved_at or assoc.created_at).timestamp() - return status_order, date_order - - benchmark_associations = sorted(benchmark_associations, key=assoc_sorting_key) - - # Fetch benchmarks information benchmarks = {assoc.benchmark: Benchmark.get(assoc.benchmark) for assoc in benchmark_associations if assoc.benchmark} diff --git a/cli/medperf/web_ui/mlcubes/routes.py b/cli/medperf/web_ui/mlcubes/routes.py index 02390c181..ecacbb4e2 100644 --- a/cli/medperf/web_ui/mlcubes/routes.py +++ b/cli/medperf/web_ui/mlcubes/routes.py @@ -9,7 +9,7 @@ from medperf.entities.cube import Cube from medperf.entities.benchmark import Benchmark from medperf.enums import Status -from medperf.web_ui.common import templates +from medperf.web_ui.common import templates, sort_associations_display router = APIRouter() logger = logging.getLogger(__name__) @@ -38,23 +38,8 @@ def mlcubes_ui(request: Request, local_only: bool = False, mine_only: bool = Fal def mlcube_detail_ui(request: Request, mlcube_id: int): mlcube = Cube.get(cube_uid=mlcube_id, valid_only=False) - # Fetching associations and related benchmarks benchmarks_associations = Cube.get_benchmarks_associations(mlcube_uid=mlcube_id) - - approval_status_order = { - Status.PENDING: 0, - Status.APPROVED: 1, - Status.REJECTED: 2, - } - - def assoc_sorting_key(assoc): - # lower status - first - status_order = approval_status_order.get(assoc.approval_status, -1) - # recent associations - first - date_order = -(assoc.approved_at or assoc.created_at).timestamp() - return status_order, date_order - - benchmarks_associations = sorted(benchmarks_associations, key=assoc_sorting_key) + benchmarks_associations = sort_associations_display(benchmarks_associations) benchmarks = {assoc.benchmark: Benchmark.get(assoc.benchmark) for assoc in benchmarks_associations if assoc.benchmark}