Skip to content

Commit

Permalink
get_all_questions
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda committed Jul 19, 2023
1 parent c611af2 commit fc793e1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 33 deletions.
42 changes: 9 additions & 33 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import plotly
import plotly.express as px
import plotly.graph_objects as go
from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember
from .types import SQLAnswer, Explanation, QuestionSQLPair, Question, QuestionId, DataResult, PlotlyResult, Status, FullQuestionDocument, QuestionList, QuestionCategory, AccuracyStats, UserEmail, UserOTP, ApiKey, OrganizationList, Organization, NewOrganization, StringData, QuestionStringList, Visibility, NewOrganizationMember, DataFrameJSON
from typing import List, Dict, Any, Union, Optional, Callable, Tuple
import warnings
import traceback
Expand Down Expand Up @@ -783,55 +783,31 @@ def generate_question(sql: str) -> str:

return question.question

def get_flagged_questions() -> QuestionList:
def get_all_questions() -> pd.DataFrame:
"""
## Example
```python
questions = vn.get_flagged_questions()
questions = vn.get_all_questions()
```
Get a list of flagged questions from the Vanna.AI API.
Get a list of questions from the Vanna.AI API.
Returns:
List[FullQuestionDocument] or None: The list of flagged questions, or None if an error occurred.
pd.DataFrame or None: The list of questions, or None if an error occurred.
"""
# params = [Question(question="")]
params = []

d = __rpc_call(method="get_flagged_questions", params=params)
d = __rpc_call(method="get_all_questions", params=params)

if 'result' not in d:
return None

# Load the result into a dataclass
flagged_questions = QuestionList(**d['result'])
all_questions = DataFrameJSON(**d['result'])

return flagged_questions
df = pd.read_json(all_questions.data)

def get_accuracy_stats() -> AccuracyStats:
"""
## Example
```python
vn.get_accuracy_stats()
```
Get the accuracy statistics from the Vanna.AI API.
Returns:
dict or None: The accuracy statistics, or None if an error occurred.
"""
params = []

d = __rpc_call(method="get_accuracy_stats", params=params)

if 'result' not in d:
return None

# Load the result into a dataclass
accuracy_stats = AccuracyStats(**d['result'])

return accuracy_stats
return df
4 changes: 4 additions & 0 deletions src/vanna/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,8 @@ class Diagram:

@dataclass
class StringData:
data: str

@dataclass
class DataFrameJSON:
data: str

0 comments on commit fc793e1

Please sign in to comment.