Commit 64588c2 (1 parent: bc36fbd)
Showing 2 changed files with 75 additions and 61 deletions.
First changed file (the flake8 configuration):

@@ -1,4 +1,3 @@
 [flake8]
 max-line-length = 120
 exclude = tests/*
 ignore = E501, E722, W391, F821
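For context: E501 is "line too long", E722 is "bare except", W391 is "blank line at end of file", and F821 is "undefined name". With E501 suppressed, the manual line wrapping in the second file below is stylistic rather than required by CI. flake8 can also be driven from Python through its documented legacy API; a minimal sketch, assuming flake8 is installed (the file path is hypothetical, and passing `max_line_length` as a keyword option is an assumption based on flake8's option-name mapping):

```python
# Sketch: run flake8 programmatically via its legacy API.
from flake8.api import legacy as flake8

style_guide = flake8.get_style_guide(
    max_line_length=120,
    ignore=["E501", "E722", "W391", "F821"],
)
report = style_guide.check_files(["src/vanna/__init__.py"])  # hypothetical path
print(f"{report.total_errors} issue(s) found")
```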
Second changed file (the main Python module):
@@ -85,8 +85,8 @@
 import sqlparse
 from dataclasses import dataclass

-from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, \
-    FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, \
+from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, DataResult, PlotlyResult, Status, \
+    QuestionCategory, UserEmail, UserOTP, ApiKey, OrganizationList, \
     Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON
 from typing import List, Union, Callable, Tuple
 from .exceptions import ImproperlyConfigured, DependencyError, ConnectionError, OTPCodeError, SQLRemoveError, \
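This hunk drops the unused `QuestionId`, `FullQuestionDocument`, `QuestionList`, and `AccuracyStats` imports (F821/unused-name cleanup) but keeps the backslash continuations. PEP 8 prefers implicit continuation inside parentheses for long imports; a hypothetical restyling of the trimmed import, shown only for illustration:

```python
# Equivalent to the backslash-continued form above, using parentheses.
from .types import (
    SQLAnswer, Explanation, QuestionSQLPair, Question, DataResult,
    PlotlyResult, Status, QuestionCategory, UserEmail, UserOTP, ApiKey,
    OrganizationList, Organization, NewOrganization, StringData,
    QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON,
)
```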
@@ -116,6 +116,7 @@
 _endpoint = "https://ask.vanna.ai/rpc"
 _unauthenticated_endpoint = "https://ask.vanna.ai/unauthenticated_rpc"

+
 def __unauthenticated_rpc_call(method, params):
     headers = {
         'Content-Type': 'application/json',
@@ -160,9 +161,11 @@ def __rpc_call(method, params):
     response = requests.post(_endpoint, headers=headers, data=json.dumps(data))
     return response.json()

+
 def __dataclass_to_dict(obj):
     return dataclasses.asdict(obj)

+
 def get_api_key(email: str, otp_code: Union[str, None] = None) -> str:
     """
     **Example:**
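`__dataclass_to_dict` is a one-line wrapper over `dataclasses.asdict`, which converts a dataclass instance to a plain dict and recurses into nested dataclasses, lists, and dicts. A self-contained sketch (the `Status` fields here are illustrative, not the library's exact definition):

```python
import dataclasses
from dataclasses import dataclass

@dataclass
class Status:
    success: bool
    message: str  # illustrative field, not necessarily the real one

def __dataclass_to_dict(obj):
    return dataclasses.asdict(obj)

# asdict recurses, so nested dataclasses serialize cleanly for JSON-RPC payloads.
assert __dataclass_to_dict(Status(True, "ok")) == {"success": True, "message": "ok"}
```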
@@ -238,7 +241,9 @@ def set_api_key(key: str) -> None:
     models = get_models()

     if len(models) == 0:
-        raise ConnectionError("There was an error communicating with the Vanna.AI API. Please try again or contact [email protected]")
+        raise ConnectionError(
+            "There was an error communicating with the Vanna.AI API. Please try again or contact [email protected]")
+

 def get_models() -> List[str]:
     """
@@ -356,6 +361,7 @@ def update_model_visibility(public: bool) -> bool:

     return status.success

+
 def _set_org(org: str) -> None:
     global __org
@@ -505,6 +511,7 @@ def add_documentation(documentation: str) -> bool:

     return status.success

+
 @dataclass
 class TrainingPlanItem:
     item_type: str
@@ -544,7 +551,7 @@ def __init__(self, plan: List[TrainingPlanItem]):

     def __str__(self):
         return "\n".join(self.get_summary())

     def __repr__(self):
         return self.__str__()

@@ -584,7 +591,6 @@ def remove_item(self, item: str):
                 self._plan.remove(plan_item)
                 break

-

 def __get_databases() -> List[str]:
     try:
@@ -594,15 +600,16 @@ def __get_databases() -> List[str]:
         df_databases = run_sql("SHOW DATABASES")
     except:
         return []

     return df_databases['DATABASE_NAME'].unique().tolist()


 def __get_information_schema_tables(database: str) -> pd.DataFrame:
     df_tables = run_sql(f'SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES')

     return df_tables

+
 def get_training_plan_experimental(filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True) -> TrainingPlan:
     """
     **EXPERIMENTAL** : This method is experimental and may change in future versions.
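In `__get_databases`, `unique()` deduplicates while preserving first-seen order and returns a NumPy array; `tolist()` then converts it to a plain Python list. The same chain in isolation, with made-up data:

```python
import pandas as pd

df_databases = pd.DataFrame({"DATABASE_NAME": ["SALES", "SALES", "HR"]})
# unique() keeps first-seen order; tolist() yields a plain Python list.
print(df_databases["DATABASE_NAME"].unique().tolist())  # ['SALES', 'HR']
```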
@@ -625,15 +632,18 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
     if use_historical_queries:
         try:
             print("Trying query history")
-            df_history = run_sql(""" select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""")
+            df_history = run_sql(
+                """ select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""")

             df_history_filtered = df_history.query('ROWS_PRODUCED > 1')
             if filter_databases is not None:
-                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
+                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(
+                    lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
                 df_history_filtered = df_history_filtered[mask]

             if filter_schemas is not None:
-                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_schemas]))
+                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(
+                    lambda x: any(s in x for s in [s.lower() for s in filter_schemas]))
                 df_history_filtered = df_history_filtered[mask]

             for query in df_history_filtered.sample(10)['QUERY_TEXT'].unique().tolist():
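The wrapped mask expressions implement a case-insensitive substring filter: lower-case the query text, then keep rows mentioning any of the requested names. The same pattern in isolation, with made-up data:

```python
import pandas as pd

df_history_filtered = pd.DataFrame({"QUERY_TEXT": [
    "SELECT * FROM SALES.PUBLIC.ORDERS",
    "SELECT CURRENT_TIMESTAMP()",
]})
filter_databases = ["Sales"]

# Lower-case both sides, then test each row against every filter name.
mask = df_history_filtered["QUERY_TEXT"].str.lower().apply(
    lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
print(df_history_filtered[mask])  # keeps only the SALES query
```

Worth noting: the subsequent `df_history_filtered.sample(10)` raises a ValueError when fewer than 10 rows survive the filter, which the enclosing try/except absorbs by printing the exception.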
@@ -648,7 +658,7 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
             print(e)

     databases = __get_databases()

     for database in databases:
         if filter_databases is not None and database not in filter_databases:
             continue
@@ -674,15 +684,17 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
             for table in tables:
                 df_columns_filtered_to_table = df_columns_filtered_to_schema.query(f"TABLE_NAME == '{table}'")
                 doc = f"The following columns are in the {table} table in the {database} database:\n\n"
-                doc += df_columns_filtered_to_table[["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "COMMENT"]].to_markdown()
+                doc += df_columns_filtered_to_table[
+                    ["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE",
+                     "COMMENT"]].to_markdown()

                 plan._plan.append(TrainingPlanItem(
                     item_type=TrainingPlanItem.ITEM_TYPE_IS,
                     item_group=f"{database}.{schema}",
                     item_name=table,
                     item_value=doc
                 ))

         except Exception as e:
             print(e)
             pass
@@ -711,36 +723,36 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
     # print("Trying INFORMATION_SCHEMA.TABLES")
     # df = run_sql("SELECT * FROM INFORMATION_SCHEMA.TABLES")

     # breakpoint()

     # try:
     #     print("Trying SCHEMATA")
     #     df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA")

     #     for schema in df_schemata.schema_name.unique():
     #         df = run_sql(f"SELECT * FROM {schema}.information_schema.tables")

     #         for table in df.table_name.unique():
     #             plan._plan.append(TrainingPlanItem(
     #                 item_type=TrainingPlanItem.ITEM_TYPE_IS,
     #                 item_group=schema,
     #                 item_name=table,
     #                 item_value=None
     #             ))

     #             try:
     #                 ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')")

     #                 plan._plan.append(TrainingPlanItem(
     #                     item_type=TrainingPlanItem.ITEM_TYPE_DDL,
     #                     item_group=schema,
     #                     item_name=None,
     #                     item_value=ddl_df.iloc[0, 0]
     #                 ))
     #             except:
     #                 pass
     # except:
     #     pass

     # breakpoint()

     return plan
@@ -753,7 +765,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:
     vn.train()
     ```
     Train Vanna.AI on a question and its corresponding SQL query.
-    Train Vanna.AI on a question and its corresponding SQL query.
     If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database.
     If you call it with the sql argument, it's equivalent to [`add_sql()`][vanna.add_sql].
     If you call it with the ddl argument, it's equivalent to [`add_ddl()`][vanna.add_ddl].
@@ -820,7 +832,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation: | |
print("Not able to add sql.") | ||
return False | ||
return False | ||
|
||
if plan: | ||
for item in plan._plan: | ||
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL: | ||
|
@@ -915,7 +927,7 @@ def remove_sql(question: str) -> bool:
     d = __rpc_call(method="remove_sql", params=params)

     if 'result' not in d:
-        raise Exception(f"Error removing SQL")
+        raise Exception("Error removing SQL")
         return False

     status = Status(**d['result'])
@@ -943,7 +955,7 @@ def remove_training_data(id: str) -> bool:
     d = __rpc_call(method="remove_training_data", params=params)

     if 'result' not in d:
-        raise APIError(f"Error removing training data")
+        raise APIError("Error removing training data")

     status = Status(**d['result'])
@@ -1110,11 +1122,11 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai

     if print_results:
         try:
             Code = __import__('IPython.display', fromlist=['Code']).Code
             display(Code(sql))
-        except Exception as e:
+        except Exception:
             print(sql)

     if run_sql is None:
         print("If you want to run the SQL query, provide a vn.run_sql function.")
@@ -1130,11 +1142,11 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
         try:
             display = __import__('IPython.display', fromlist=['display']).display
             display(df)
-        except Exception as e:
+        except Exception:
             print(df)

         if len(df) > 0 and auto_train:
-            add_sql(question=question, sql=sql, tag=types.QuestionCategory.SQL_RAN)
+            add_sql(question=question, sql=sql, tag=QuestionCategory.SQL_RAN)

         try:
             plotly_code = generate_plotly_code(question=question, sql=sql, df=df)
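The switch from `types.QuestionCategory.SQL_RAN` to the bare `QuestionCategory.SQL_RAN` works because `QuestionCategory` is already imported by name from `.types` in the import hunk at the top of this diff; the `types.` prefix was the F821 undefined-name hazard.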
@@ -1145,7 +1157,7 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
                 Image = __import__('IPython.display', fromlist=['Image']).Image
                 img_bytes = fig.to_image(format="png", scale=2)
                 display(Image(img_bytes))
-            except Exception as e:
+            except Exception:
                 fig.show()

     if generate_followups:
@@ -1159,10 +1171,9 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
                 display = __import__('IPython.display', fromlist=['display']).display
                 Markdown = __import__('IPython.display', fromlist=['Markdown']).Markdown
                 display(Markdown(md))
-            except Exception as e:
+            except Exception:
                 print(md)

-
     if print_results:
         return None
     else:
@@ -1190,7 +1201,8 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
         return sql, None, None, None


-def generate_plotly_code(question: Union[str, None], sql: Union[str, None], df: pd.DataFrame, chart_instructions: Union[str, None] = None) -> str:
+def generate_plotly_code(question: Union[str, None], sql: Union[str, None], df: pd.DataFrame,
+                         chart_instructions: Union[str, None] = None) -> str:
     """
     **Example:**
     ```python
@@ -1333,6 +1345,7 @@ def generate_explanation(sql: str) -> str:

     return explanation.explanation

+
 def generate_question(sql: str) -> str:
     """
@@ -1426,6 +1439,7 @@ def get_training_data() -> pd.DataFrame:

     return df

+
 def connect_to_sqlite(url: str):
     """
     Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
@@ -1458,6 +1472,7 @@ def run_sql_sqlite(sql: str):
     global run_sql
     run_sql = run_sql_sqlite

+
 def connect_to_snowflake(account: str, username: str, password: str, database: str, role: Union[str, None] = None):
     """
     Connect to Snowflake using the Snowflake connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
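`connect_to_sqlite` works by rebinding a module-level `run_sql` hook that the rest of the package calls. A simplified sketch of that injection pattern (the real helper's URL/download handling is omitted):

```python
import sqlite3
import pandas as pd

run_sql = None  # module-level hook; ask() and friends call through this

def connect_to_sqlite(url: str):
    conn = sqlite3.connect(url)

    def run_sql_sqlite(sql: str) -> pd.DataFrame:
        # Return a DataFrame, matching what the rest of the module
        # expects from run_sql.
        return pd.read_sql_query(sql, conn)

    global run_sql
    run_sql = run_sql_sqlite

connect_to_sqlite(":memory:")
print(run_sql("SELECT 1 AS x"))
```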
@@ -1487,7 +1502,7 @@ def connect_to_snowflake(account: str, username: str, password: str, database: s
         snowflake = __import__('snowflake.connector')
     except ImportError:
         raise DependencyError("You need to install required dependencies to execute this method, run command:"
                               " \npip install vanna[snowflake]")

     if username == 'my-username':
         username_env = os.getenv('SNOWFLAKE_USERNAME')
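The two adjacent string literals inside `DependencyError` concatenate implicitly, and this hunk (like the next) only realigns the continuation line's indentation. When `__import__` is given a dotted name without a fromlist, it imports the submodule but returns the top-level package. A sketch of the guard on its own (the exception class here is a stand-in for vanna's `.exceptions.DependencyError`):

```python
class DependencyError(Exception):
    """Stand-in for vanna's exceptions.DependencyError."""

def _import_snowflake():
    try:
        # Without a fromlist, __import__("snowflake.connector") imports
        # the submodule but returns the top-level "snowflake" package.
        return __import__("snowflake.connector")
    except ImportError:
        raise DependencyError("You need to install required dependencies to execute this method, run command:"
                              " \npip install vanna[snowflake]")
```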
@@ -1575,7 +1590,7 @@ def connect_to_postgres(host: str = None, dbname: str = None, user: str = None,
         import psycopg2.extras
     except ImportError:
         raise DependencyError("You need to install required dependencies to execute this method,"
                               " run command: \npip install vanna[postgres]")

     if not host:
         host = os.getenv('HOST')