diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index d240398c..c164c910 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -240,6 +240,9 @@ def create_dataset(dataset: str, db_type: str) -> bool: status = Status(**d['result']) + if status.success: + __org = dataset + return status.success def add_user_to_dataset(dataset: str, email: str, is_admin: bool) -> bool: diff --git a/tests/test_vanna.py b/tests/test_vanna.py index d2056dfd..888bc8af 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -199,4 +199,11 @@ def test_remove_training_data(): rv = vn.remove_training_data(row['id']) assert rv == True - assert vn.get_training_data().shape[0] == 2-index \ No newline at end of file + assert vn.get_training_data().shape[0] == 2-index + +def test_create_dataset_and_add_user(): + created = vn.create_dataset('test_org2', 'Snowflake') + assert created == True + + added = vn.add_user_to_dataset(dataset='test_org2', email="user5@example.com", is_admin=False) + assert added == True \ No newline at end of file