From fc793e19d3c631af3f82cdd11feaee6c543829d4 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Wed, 19 Jul 2023 00:36:25 -0400 Subject: [PATCH] get_all_questions --- src/vanna/__init__.py | 42 +++++++++--------------------------------- src/vanna/types.py | 4 ++++ 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index e7193dfb..5ce32a0e 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -12,7 +12,7 @@ import plotly import plotly.express as px import plotly.graph_objects as go -from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember +from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON from typing import List, Dict, Any, Union, Optional, Callable, Tuple import warnings import traceback @@ -783,55 +783,31 @@ def generate_question(sql: str) -> str: return question.question -def get_flagged_questions() -> QuestionList: +def get_all_questions() -> pd.DataFrame: """ ## Example ```python - questions = vn.get_flagged_questions() + questions = vn.get_all_questions() ``` - Get a list of flagged questions from the Vanna.AI API. + Get a list of questions from the Vanna.AI API. Returns: - List[FullQuestionDocument] or None: The list of flagged questions, or None if an error occurred. + pd.DataFrame or None: The list of questions, or None if an error occurred. """ # params = [Question(question="")] params = [] - d = __rpc_call(method="get_flagged_questions", params=params) + d = __rpc_call(method="get_all_questions", params=params) if 'result' not in d: return None # Load the result into a dataclass - flagged_questions = QuestionList(**d['result']) + all_questions = DataFrameJSON(**d['result']) - return flagged_questions + df = pd.read_json(all_questions.data) -def get_accuracy_stats() -> AccuracyStats: - """ - - ## Example - ```python - vn.get_accuracy_stats() - ``` - - Get the accuracy statistics from the Vanna.AI API. - - Returns: - dict or None: The accuracy statistics, or None if an error occurred. - - """ - params = [] - - d = __rpc_call(method="get_accuracy_stats", params=params) - - if 'result' not in d: - return None - - # Load the result into a dataclass - accuracy_stats = AccuracyStats(**d['result']) - - return accuracy_stats \ No newline at end of file + return df diff --git a/src/vanna/types.py b/src/vanna/types.py index ef7efe99..1c6f0d34 100644 --- a/src/vanna/types.py +++ b/src/vanna/types.py @@ -159,4 +159,8 @@ class Diagram: @dataclass class StringData: + data: str + +@dataclass +class DataFrameJSON: data: str \ No newline at end of file