Skip to content

Commit

Permalink
fix minari list table when unsorted
Browse files (browse the repository at this point in the history)
  • Loading branch information
younik committed Nov 7, 2024
1 parent 70f9029 commit 701b617
Showing 1 changed file with 46 additions and 54 deletions.
100 changes: 46 additions & 54 deletions minari/cli.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,19 +33,16 @@ def _version_callback(value: bool):

def _show_dataset_table(datasets: Dict[str, Dict[str, Any]], table_title: str):
# Collect compatible versions of each dataset
dataset_versions = defaultdict(list)
dataset_hierarchy = defaultdict(lambda: defaultdict(list))

display_versions = False
for dataset_id in datasets.keys():
namespace, dataset_name, version = parse_dataset_id(dataset_id)
dataset_id_versionless = gen_dataset_id(namespace, dataset_name)
dataset_versions[dataset_id_versionless].append(version)

# "Versions" column is only displayed if there are multiple versions
display_versions = any([len(x) > 1 for x in dataset_versions.values()])
dataset_hierarchy[namespace][dataset_name].append(version)
display_versions = display_versions or len(dataset_hierarchy[namespace][dataset_name]) > 1

# Build the Rich Table
table = Table(title=table_title)

table.add_column("Name", justify="left", style="cyan", no_wrap=True)

if display_versions:
Expand All @@ -56,54 +53,49 @@ def _show_dataset_table(datasets: Dict[str, Dict[str, Any]], table_title: str):
table.add_column("Dataset Size", justify="left", style="green", no_wrap=True)
table.add_column("Author", justify="left", style="magenta", no_wrap=True)

previous_namespace = None

for dataset_prefix, versions in dataset_versions.items():
dataset_id = f"{dataset_prefix}-v{max(versions)}"
dst_metadata = datasets[dataset_id]
author = dst_metadata.get("author", "Unknown")
if not isinstance(author, str) and isinstance(author, Iterable):
author = ", ".join(author)
dataset_size = dst_metadata.get("dataset_size", "Unknown")
if dataset_size != "Unknown":
dataset_size = f"{str(dataset_size)} MB"
author_email = dst_metadata.get("author_email", "Unknown")
if not isinstance(author_email, str) and isinstance(author_email, Iterable):
author_email = ", ".join(author_email)

assert isinstance(dst_metadata["dataset_id"], str)
assert isinstance(author, str)
assert isinstance(author_email, str)

docs_url = dst_metadata.get("docs_url", None)
compatible_versions = ", ".join(
[f"v{x}" for x in sorted(versions, reverse=True)]
)

if docs_url is not None:
dataset_id_text = f"[link={docs_url}]{dataset_id}[/link]"
else:
dataset_id_text = dataset_id

namespace, _, _ = parse_dataset_id(dataset_id)

if namespace != previous_namespace:
table.add_section()
previous_namespace = namespace

# Build the current table row
rows = []
rows.append(dataset_id_text)

if display_versions:
rows.append(compatible_versions)

rows.append(str(dst_metadata["total_episodes"]))
rows.append(str(dst_metadata["total_steps"]))
rows.append(dataset_size)
rows.append(author)
table.add_row(*rows)
for namespace, namespace_datasets in dataset_hierarchy.items():
for dataset_name, versions in namespace_datasets.items():
dataset_id = gen_dataset_id(namespace, dataset_name, max(versions))

dst_metadata = datasets[dataset_id]
author = dst_metadata.get("author", "Unknown")
if not isinstance(author, str) and isinstance(author, Iterable):
author = ", ".join(author)
dataset_size = dst_metadata.get("dataset_size", "Unknown")
if dataset_size != "Unknown":
dataset_size = f"{str(dataset_size)} MB"
author_email = dst_metadata.get("author_email", "Unknown")
if not isinstance(author_email, str) and isinstance(author_email, Iterable):
author_email = ", ".join(author_email)

assert isinstance(dst_metadata["dataset_id"], str)
assert isinstance(author, str)
assert isinstance(author_email, str)

docs_url = dst_metadata.get("docs_url", None)
compatible_versions = ", ".join(
[f"v{x}" for x in sorted(versions, reverse=True)]
)

if docs_url is not None:
dataset_id_text = f"[link={docs_url}]{dataset_id}[/link]"
else:
dataset_id_text = dataset_id

# Build the current table row
rows = []
rows.append(dataset_id_text)

if display_versions:
rows.append(compatible_versions)

rows.append(str(dst_metadata["total_episodes"]))
rows.append(str(dst_metadata["total_steps"]))
rows.append(dataset_size)
rows.append(author)
table.add_row(*rows)

table.add_section()
print(table)


Expand Down

0 comments on commit 701b617

Please sign in to comment.