diff --git a/moshi/pyproject.toml b/moshi/pyproject.toml index 32b74b3..36775cd 100644 --- a/moshi/pyproject.toml +++ b/moshi/pyproject.toml @@ -19,6 +19,10 @@ license = {text = "MIT"} dynamic = ["version"] readme = "README.md" +[project.scripts] +moshi-server = "moshi.server:main" +moshi-client = "moshi.client:main" + [tool.setuptools.dynamic] version = {attr = "moshi.__version__"} diff --git a/moshi_mlx/moshi_mlx/local.py b/moshi_mlx/moshi_mlx/local.py index 812442f..7b4222c 100644 --- a/moshi_mlx/moshi_mlx/local.py +++ b/moshi_mlx/moshi_mlx/local.py @@ -250,7 +250,7 @@ async def go(): pass -def main(printer: AnyPrinter): +def main(): parser = argparse.ArgumentParser() parser.add_argument("--tokenizer", type=str) parser.add_argument("--moshi-weight", type=str) @@ -276,6 +276,12 @@ def main(printer: AnyPrinter): server_to_client = multiprocessing.Queue() printer_q = multiprocessing.Queue() + printer: AnyPrinter + if sys.stdout.isatty(): + printer = Printer() + else: + printer = RawPrinter() + # Create two processes subprocess_args = printer_q, client_to_server, server_to_client, args p1 = multiprocessing.Process(target=client, args=subprocess_args) @@ -372,9 +378,4 @@ def main(printer: AnyPrinter): if __name__ == "__main__": - printer: AnyPrinter - if sys.stdout.isatty(): - printer = Printer() - else: - printer = RawPrinter() - main(printer) + main() diff --git a/moshi_mlx/pyproject.toml b/moshi_mlx/pyproject.toml index ccf12f3..943067c 100644 --- a/moshi_mlx/pyproject.toml +++ b/moshi_mlx/pyproject.toml @@ -19,6 +19,9 @@ license = {text = "MIT"} dynamic = ["version"] readme = "README.md" +[project.scripts] +moshi-local = "moshi_mlx.local:main" +moshi-local-web = "moshi_mlx.local_web:main" [build-system] requires = ["setuptools"]