diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py
index ae3c06ab1..8d2dde1df 100644
--- a/src/vanna/__init__.py
+++ b/src/vanna/__init__.py
@@ -85,8 +85,8 @@
 import sqlparse
 from dataclasses import dataclass
 
-from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, \
-    FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, \
+from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, DataResult, PlotlyResult, Status, \
+    QuestionCategory, UserEmail, UserOTP, ApiKey, OrganizationList, \
     Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON
 from typing import List, Union, Callable, Tuple
 from .exceptions import ImproperlyConfigured, DependencyError, ConnectionError, OTPCodeError, SQLRemoveError, \
@@ -116,6 +116,7 @@
 _endpoint = "https://ask.vanna.ai/rpc"
 _unauthenticated_endpoint = "https://ask.vanna.ai/unauthenticated_rpc"
 
+
 def __unauthenticated_rpc_call(method, params):
     headers = {
         'Content-Type': 'application/json',
@@ -160,9 +161,11 @@ def __rpc_call(method, params):
     response = requests.post(_endpoint, headers=headers, data=json.dumps(data))
     return response.json()
 
+
 def __dataclass_to_dict(obj):
     return dataclasses.asdict(obj)
 
+
 def get_api_key(email: str, otp_code: Union[str, None] = None) -> str:
     """
     **Example:**
@@ -238,7 +241,9 @@ def set_api_key(key: str) -> None:
     models = get_models()
 
     if len(models) == 0:
-        raise ConnectionError("There was an error communicating with the Vanna.AI API. Please try again or contact support@vanna.ai")
+        raise ConnectionError(
+            "There was an error communicating with the Vanna.AI API. Please try again or contact support@vanna.ai")
+
 
 def get_models() -> List[str]:
     """
@@ -356,6 +361,7 @@ def update_model_visibility(public: bool) -> bool:
 
     return status.success
 
+
 def _set_org(org: str) -> None:
     global __org
@@ -505,6 +511,7 @@ def add_documentation(documentation: str) -> bool:
 
     return status.success
 
+
 @dataclass
 class TrainingPlanItem:
     item_type: str
@@ -544,7 +551,7 @@ def __init__(self, plan: List[TrainingPlanItem]):
 
     def __str__(self):
         return "\n".join(self.get_summary())
-    
+
     def __repr__(self):
         return self.__str__()
@@ -584,7 +591,6 @@ def remove_item(self, item: str):
                 self._plan.remove(plan_item)
                 break
 
-
 def __get_databases() -> List[str]:
     try:
@@ -594,12 +600,12 @@ def __get_databases() -> List[str]:
             df_databases = run_sql("SHOW DATABASES")
         except:
             return []
-    
+
     return df_databases['DATABASE_NAME'].unique().tolist()
 
+
 def __get_information_schema_tables(database: str) -> pd.DataFrame:
     df_tables = run_sql(f'SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES')
-
     return df_tables
@@ -625,15 +631,18 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
     if use_historical_queries:
         try:
             print("Trying query history")
-            df_history = run_sql(""" select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""")
+            df_history = run_sql(
+                """ select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""")
 
             df_history_filtered = df_history.query('ROWS_PRODUCED > 1')
 
             if filter_databases is not None:
-                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
+                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(
+                    lambda x: any(s in x for s in [s.lower() for s in filter_databases]))
                 df_history_filtered = df_history_filtered[mask]
 
             if filter_schemas is not None:
-                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(lambda x: any(s in x for s in [s.lower() for s in filter_schemas]))
+                mask = df_history_filtered['QUERY_TEXT'].str.lower().apply(
+                    lambda x: any(s in x for s in [s.lower() for s in filter_schemas]))
                 df_history_filtered = df_history_filtered[mask]
             for query in df_history_filtered.sample(10)['QUERY_TEXT'].unique().tolist():
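For reference, the re-wrapped lambdas in the hunk above implement a case-insensitive substring filter over the Snowflake query history. A minimal standalone sketch of that pattern follows; the sample data and filter values are made up, and only the `QUERY_TEXT` column name and the lowercase-and-match logic come from the diff:

```python
import pandas as pd

# Hypothetical stand-in for the query-history result frame.
df_history = pd.DataFrame({"QUERY_TEXT": [
    "SELECT * FROM SALES.PUBLIC.ORDERS",
    "SELECT 1",
    "SELECT * FROM HR.PUBLIC.EMPLOYEES",
]})

filter_databases = ["Sales"]

# Same idea as the diff: lowercase the filters once, then keep rows
# whose lowercased query text contains any of them.
lowered = [s.lower() for s in filter_databases]
mask = df_history["QUERY_TEXT"].str.lower().apply(lambda x: any(s in x for s in lowered))

print(df_history[mask])  # only the SALES query survives
```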
@@ -648,7 +657,7 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
             print(e)
 
     databases = __get_databases()
-    
+
     for database in databases:
         if filter_databases is not None and database not in filter_databases:
             continue
@@ -674,15 +683,17 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
                 for table in tables:
                     df_columns_filtered_to_table = df_columns_filtered_to_schema.query(f"TABLE_NAME == '{table}'")
                     doc = f"The following columns are in the {table} table in the {database} database:\n\n"
-                    doc += df_columns_filtered_to_table[["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE", "COMMENT"]].to_markdown()
-                    
+                    doc += df_columns_filtered_to_table[
+                        ["TABLE_CATALOG", "TABLE_SCHEMA", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE",
+                         "COMMENT"]].to_markdown()
+
                     plan._plan.append(TrainingPlanItem(
                         item_type=TrainingPlanItem.ITEM_TYPE_IS,
                         item_group=f"{database}.{schema}",
                         item_name=table,
                         item_value=doc
                     ))
-                    
+
         except Exception as e:
             print(e)
             pass
@@ -711,36 +722,36 @@ def get_training_plan_experimental(filter_databases: Union[List[str], None] = No
     # print("Trying INFORMATION_SCHEMA.TABLES")
     # df = run_sql("SELECT * FROM INFORMATION_SCHEMA.TABLES")
 
-    # breakpoint()
-
-    # try:
-    #     print("Trying SCHEMATA")
-    #     df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA")
-
-    #     for schema in df_schemata.schema_name.unique():
-    #         df = run_sql(f"SELECT * FROM {schema}.information_schema.tables")
-
-    #         for table in df.table_name.unique():
-    #             plan._plan.append(TrainingPlanItem(
-    #                 item_type=TrainingPlanItem.ITEM_TYPE_IS,
-    #                 item_group=schema,
-    #                 item_name=table,
-    #                 item_value=None
-    #             ))
-
-    #         try:
-    #             ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')")
-
-    #             plan._plan.append(TrainingPlanItem(
-    #                 item_type=TrainingPlanItem.ITEM_TYPE_DDL,
-    #                 item_group=schema,
-    #                 item_name=None,
-    #                 item_value=ddl_df.iloc[0, 0]
-    #             ))
-    #         except:
-    #             pass
-    #     except:
-    #         pass
+    # breakpoint()
+
+    # try:
+    #     print("Trying SCHEMATA")
+    #     df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA")
+
+    #     for schema in df_schemata.schema_name.unique():
+    #         df = run_sql(f"SELECT * FROM {schema}.information_schema.tables")
+
+    #         for table in df.table_name.unique():
+    #             plan._plan.append(TrainingPlanItem(
+    #                 item_type=TrainingPlanItem.ITEM_TYPE_IS,
+    #                 item_group=schema,
+    #                 item_name=table,
+    #                 item_value=None
+    #             ))
+
+    #         try:
+    #             ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')")
+
+    #             plan._plan.append(TrainingPlanItem(
+    #                 item_type=TrainingPlanItem.ITEM_TYPE_DDL,
+    #                 item_group=schema,
+    #                 item_name=None,
+    #                 item_value=ddl_df.iloc[0, 0]
+    #             ))
+    #         except:
+    #             pass
+    #     except:
+    #         pass
 
     return plan
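The plan built above is just a list of `TrainingPlanItem` dataclass instances with the four fields this diff shows (`item_type`, `item_group`, `item_name`, `item_value`). A self-contained sketch of that accumulation; the `ITEM_TYPE_*` constant values and the sample strings are assumptions, since the diff shows only the constant names:

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class TrainingPlanItem:
    item_type: str
    item_group: str
    item_name: str
    item_value: Union[str, None]

    ITEM_TYPE_IS = "is"    # assumed values; the diff only references the names
    ITEM_TYPE_DDL = "ddl"


# Mirrors the loop in get_training_plan_experimental(): one information-schema
# item per table, grouped as "<database>.<schema>".
plan: List[TrainingPlanItem] = []
plan.append(TrainingPlanItem(
    item_type=TrainingPlanItem.ITEM_TYPE_IS,
    item_group="SALES.PUBLIC",
    item_name="ORDERS",
    item_value="The following columns are in the ORDERS table in the SALES database:\n\n...",
))

for item in plan:
    print(f"{item.item_type}: {item.item_group}.{item.item_name}")
```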
@@ -753,7 +764,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:
     vn.train()
     ```
 
-    Train Vanna.AI on a question and its corresponding SQL query. 
+    Train Vanna.AI on a question and its corresponding SQL query.
     If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database.
     If you call it with the sql argument, it's equivalent to [`add_sql()`][vanna.add_sql].
     If you call it with the ddl argument, it's equivalent to [`add_ddl()`][vanna.add_ddl].
@@ -820,7 +831,7 @@ def train(question: str = None, sql: str = None, ddl: str = None, documentation:
                 print("Not able to add sql.")
                 return False
         return False
-    
+
     if plan:
         for item in plan._plan:
             if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
@@ -915,7 +926,7 @@ def remove_sql(question: str) -> bool:
     d = __rpc_call(method="remove_sql", params=params)
 
     if 'result' not in d:
-        raise Exception(f"Error removing SQL")
+        raise Exception("Error removing SQL")
         return False
 
     status = Status(**d['result'])
@@ -943,7 +954,7 @@ def remove_training_data(id: str) -> bool:
    d = __rpc_call(method="remove_training_data", params=params)

    if 'result' not in d:
-        raise APIError(f"Error removing training data")
+        raise APIError("Error removing training data")
 
    status = Status(**d['result'])
@@ -1110,11 +1121,11 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
     if print_results:
         try:
-            Code = __import__('IPython.display', fromlist=['Code']).Code
-            display(Code(sql))
-        except Exception as e:
+            Code = __import__('IPython.display', fromlist=['Code']).Code
+            display(Code(sql))
+        except Exception:
             print(sql)
-    
+
     if run_sql is None:
         print("If you want to run the SQL query, provide a vn.run_sql function.")
@@ -1130,11 +1141,11 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
         try:
             display = __import__('IPython.display', fromlist=['display']).display
             display(df)
-        except Exception as e:
+        except Exception:
             print(df)
 
         if len(df) > 0 and auto_train:
-            add_sql(question=question, sql=sql, tag=types.QuestionCategory.SQL_RAN)
+            add_sql(question=question, sql=sql, tag=QuestionCategory.SQL_RAN)
 
         try:
             plotly_code = generate_plotly_code(question=question, sql=sql, df=df)
@@ -1145,7 +1156,7 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
                     Image = __import__('IPython.display', fromlist=['Image']).Image
                     img_bytes = fig.to_image(format="png", scale=2)
                     display(Image(img_bytes))
-                except Exception as e:
+                except Exception:
                     fig.show()
 
             if generate_followups:
@@ -1159,10 +1170,9 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
                 display = __import__('IPython.display', fromlist=['display']).display
                 Markdown = __import__('IPython.display', fromlist=['Markdown']).Markdown
                 display(Markdown(md))
-            except Exception as e:
+            except Exception:
                 print(md)
 
-
     if print_results:
         return None
     else:
@@ -1190,7 +1200,8 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai
     return sql, None, None, None
 
 
-def generate_plotly_code(question: Union[str, None], sql: Union[str, None], df: pd.DataFrame, chart_instructions: Union[str, None] = None) -> str:
+def generate_plotly_code(question: Union[str, None], sql: Union[str, None], df: pd.DataFrame,
+                         chart_instructions: Union[str, None] = None) -> str:
     """
     **Example:**
     ```python
@@ -1333,6 +1344,7 @@ def generate_explanation(sql: str) -> str:
 
     return explanation.explanation
 
+
 def generate_question(sql: str) -> str:
     """
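The `except Exception as e:` → `except Exception:` changes above drop a bound name that was never used; the surrounding try/except is an optional-IPython fallback. A runnable sketch of that fallback shape, with a made-up DataFrame (outside a notebook the `__import__` raises, and the code degrades to `print`):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})

# Same shape as ask(): render richly when IPython is importable,
# fall back to plain print() otherwise. Nothing reads the exception,
# so there is no reason to bind it.
try:
    display = __import__('IPython.display', fromlist=['display']).display
    display(df)
except Exception:
    print(df)
```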
@@ -1426,6 +1438,7 @@ def get_training_data() -> pd.DataFrame:
 
     return df
 
+
 def connect_to_sqlite(url: str):
     """
     Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
@@ -1458,6 +1471,7 @@ def run_sql_sqlite(sql: str):
     global run_sql
     run_sql = run_sql_sqlite
 
+
 def connect_to_snowflake(account: str, username: str, password: str, database: str, role: Union[str, None] = None):
     """
     Connect to Snowflake using the Snowflake connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
@@ -1487,7 +1501,7 @@ def connect_to_snowflake(account: str, username: str, password: str, database: s
         snowflake = __import__('snowflake.connector')
     except ImportError:
         raise DependencyError("You need to install required dependencies to execute this method, run command:"
-                              " \npip install vanna[snowflake]")
+                              " \npip install vanna[snowflake]")
 
     if username == 'my-username':
         username_env = os.getenv('SNOWFLAKE_USERNAME')
@@ -1575,7 +1589,7 @@ def connect_to_postgres(host: str = None, dbname: str = None, user: str = None,
         import psycopg2.extras
     except ImportError:
         raise DependencyError("You need to install required dependencies to execute this method,"
-                              " run command: \npip install vanna[postgres]")
+                              " run command: \npip install vanna[postgres]")
 
     if not host:
         host = os.getenv('HOST')
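All of the connect helpers exist to bind a database-specific runner to the module-level `run_sql` that `ask()` and `train()` check for. A condensed sketch of that pattern, assuming a plain local SQLite path; the real `connect_to_sqlite` has its own connection handling, so treat the names and details here as illustrative:

```python
import sqlite3
from typing import Callable, Union

import pandas as pd

run_sql: Union[Callable[[str], pd.DataFrame], None] = None  # what ask() checks for


def connect_to_sqlite_sketch(path: str) -> None:
    # Hypothetical stand-in for connect_to_sqlite: build a runner bound to
    # this connection, then publish it as the module-global run_sql.
    conn = sqlite3.connect(path)

    def run_sql_sqlite(sql: str) -> pd.DataFrame:
        return pd.read_sql_query(sql, conn)

    global run_sql
    run_sql = run_sql_sqlite


connect_to_sqlite_sketch(":memory:")
print(run_sql("SELECT 1 AS x"))
```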