Skip to content

Commit

Permalink
Retry tavily outside of model parsing and allow to not generate subqu…
Browse files Browse the repository at this point in the history
…eries (#104)
  • Loading branch information
kongzii authored Aug 1, 2024
1 parent ba98871 commit 0564033
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
3 changes: 3 additions & 0 deletions prediction_prophet/functions/generate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
Limit your searches to {search_limit}.
"""
def generate_subqueries(query: str, limit: int, model: str, api_key: SecretStr | None = None) -> list[str]:
if limit == 0:
return [query]

if api_key == None:
api_key = secret_str_from_env("OPENAI_API_KEY")

Expand Down
34 changes: 21 additions & 13 deletions prediction_prophet/functions/web_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import tenacity
import typing as t
from tavily import TavilyClient
from pydantic.types import SecretStr

Expand All @@ -8,19 +8,9 @@
from prediction_prophet.functions.cache import persistent_inmemory_cache


@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True)
@persistent_inmemory_cache

def web_search(query: str, max_results: int = 5, tavily_api_key: SecretStr | None = None) -> list[WebSearchResult]:
if tavily_api_key == None:
tavily_api_key = secret_str_from_env("TAVILY_API_KEY")

tavily = TavilyClient(api_key=tavily_api_key.get_secret_value() if tavily_api_key else None)
response = tavily.search(
query=query,
search_depth="advanced",
max_results=max_results,
include_raw_content=True,
)
response = _web_search(query=query, max_results=max_results, tavily_api_key=tavily_api_key)

transformed_results = [
WebSearchResult(
Expand All @@ -35,3 +25,21 @@ def web_search(query: str, max_results: int = 5, tavily_api_key: SecretStr | Non
]

return transformed_results


@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True)
@persistent_inmemory_cache
def _web_search(query: str, max_results: int = 5, tavily_api_key: SecretStr | None = None) -> dict[str, t.Any]:
if tavily_api_key == None:
tavily_api_key = secret_str_from_env("TAVILY_API_KEY")

tavily = TavilyClient(api_key=tavily_api_key.get_secret_value() if tavily_api_key else None)
response: dict[str, t.Any] = tavily.search(
query=query,
search_depth="advanced",
max_results=max_results,
include_raw_content=True,
)

return response

0 comments on commit 0564033

Please sign in to comment.