diff --git a/doltcli/dolt.py b/doltcli/dolt.py index 150874d..1cfacf7 100644 --- a/doltcli/dolt.py +++ b/doltcli/dolt.py @@ -267,6 +267,8 @@ class Dolt(DoltT): """ def __init__(self, repo_dir: str, print_output: Optional[bool] = None): + # allow ~ to be used in paths + repo_dir = os.path.expanduser(repo_dir) self.repo_dir = repo_dir self._print_output = print_output or False @@ -1366,7 +1368,7 @@ def schema_export(self, table: str, filename: Optional[str] = None): args = ["schema", "export", table] if filename: - args.extend(["--filename", filename]) + args.extend([filename]) _execute(args, self.repo_dir) return True else: diff --git a/tests/test_dolt.py b/tests/test_dolt.py index e363660..6a3c628 100644 --- a/tests/test_dolt.py +++ b/tests/test_dolt.py @@ -113,6 +113,18 @@ def test_init(tmp_path): shutil.rmtree(repo_data_dir) +def test_home_path(): + path = "~/.dolt_test" + if os.path.exists(os.path.expanduser(path)): + shutil.rmtree(os.path.expanduser(path)) + os.mkdir(os.path.expanduser(path)) + # Create empty file + open(os.path.expanduser(path + "/.dolt"), "a").close() + Dolt(path) + assert os.path.exists(path) + shutil.rmtree(path) + + def test_bad_repo_path(tmp_path): bad_repo_path = tmp_path with pytest.raises(ValueError): @@ -205,10 +217,10 @@ def test_merge_conflict(create_test_table: Tuple[Dolt, str]): with pytest.raises(DoltException): repo.merge("other", message_merge) - #commits = list(repo.log().values()) - #head_of_main = commits[0] + # commits = list(repo.log().values()) + # head_of_main = commits[0] - #assert head_of_main.message == message_two + # assert head_of_main.message == message_two def test_dolt_log(create_test_table: Tuple[Dolt, str]): @@ -400,10 +412,7 @@ def test_branch(create_test_table: Tuple[Dolt, str]): repo.checkout("dosac", checkout_branch=True) repo.checkout("main") next_active_branch, next_branches = repo.branch() - assert ( - set(branch.name for branch in next_branches) == {"main", "dosac"} - and next_active_branch.name == "main" - ) + assert set(branch.name for branch in next_branches) == {"main", "dosac"} and next_active_branch.name == "main" repo.checkout("dosac") different_active_branch, _ = repo.branch() @@ -552,17 +561,13 @@ def test_sql(create_test_table: Tuple[Dolt, str]): def test_sql_json(create_test_table: Tuple[Dolt, str]): repo, test_table = create_test_table - result = repo.sql( - query="SELECT * FROM `{table}`".format(table=test_table), result_format="json" - )["rows"] + result = repo.sql(query="SELECT * FROM `{table}`".format(table=test_table), result_format="json")["rows"] _verify_against_base_rows(result) def test_sql_csv(create_test_table: Tuple[Dolt, str]): repo, test_table = create_test_table - result = repo.sql( - query="SELECT * FROM `{table}`".format(table=test_table), result_format="csv" - ) + result = repo.sql(query="SELECT * FROM `{table}`".format(table=test_table), result_format="csv") _verify_against_base_rows(result) @@ -604,10 +609,7 @@ def test_config_global(init_empty_test_repo: Dolt): Dolt.config_global(add=True, name="user.name", value=test_username) Dolt.config_global(add=True, name="user.email", value=test_email) updated_config = Dolt.config_global(list=True) - assert ( - updated_config["user.name"] == test_username - and updated_config["user.email"] == test_email - ) + assert updated_config["user.name"] == test_username and updated_config["user.email"] == test_email Dolt.config_global(add=True, name="user.name", value=current_global_config["user.name"]) Dolt.config_global(add=True, name="user.email", value=current_global_config["user.email"]) reset_config = Dolt.config_global(list=True) @@ -623,9 +625,7 @@ def test_config_local(init_empty_test_repo: Dolt): repo.config_local(add=True, name="user.email", value=test_email) local_config = repo.config_local(list=True) global_config = Dolt.config_global(list=True) - assert ( - local_config["user.name"] == test_username and local_config["user.email"] == test_email - ) + assert local_config["user.name"] == test_username and local_config["user.email"] == test_email assert global_config["user.name"] == current_global_config["user.name"] assert global_config["user.email"] == current_global_config["user.email"] @@ -677,18 +677,14 @@ def test_clone_new_dir(tmp_path): def test_dolt_sql_csv(init_empty_test_repo: Dolt): dolt = init_empty_test_repo write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True) - result = dolt.sql( - "SELECT `name` as name, `id` as id FROM test_table ORDER BY id", result_format="csv" - ) + result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ORDER BY id", result_format="csv") compare_rows_helper(BASE_TEST_ROWS, result) def test_dolt_sql_json(init_empty_test_repo: Dolt): dolt = init_empty_test_repo write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True) - result = dolt.sql( - "SELECT `name` as name, `id` as id FROM test_table ", result_format="json" - ) + result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ", result_format="json") # JSON return value preserves some type information, we cast back to a string for row in result["rows"]: row["id"] = str(row["id"]) @@ -700,9 +696,7 @@ def test_dolt_sql_file(init_empty_test_repo: Dolt): with tempfile.NamedTemporaryFile() as f: write_rows(dolt, "test_table", BASE_TEST_ROWS, commit=True) - result = dolt.sql( - "SELECT `name` as name, `id` as id FROM test_table ", result_file=f.name - ) + result = dolt.sql("SELECT `name` as name, `id` as id FROM test_table ", result_file=f.name) res = read_csv_to_dict(f.name) compare_rows_helper(BASE_TEST_ROWS, res)