Squash 4 commits into one
kongzii committed Oct 31, 2024
1 parent aff54ca commit 906f1f2
Showing 10 changed files with 23 additions and 21 deletions.
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

7 changes: 4 additions & 3 deletions prediction_prophet/autonolas/research.py
@@ -2,6 +2,7 @@
 import os
 import math
 import tenacity
+from datetime import timedelta
 from sklearn.metrics.pairwise import cosine_similarity
 from typing import Any, Dict, Generator, List, Optional, Tuple, TypedDict
 from datetime import datetime, timezone
@@ -32,7 +33,7 @@
 from dateutil import parser
 from prediction_prophet.functions.utils import check_not_none
 from prediction_market_agent_tooling.gtypes import Probability
-from prediction_market_agent_tooling.tools.utils import secret_str_from_env
+from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.tools.caches.db_cache import db_cache
 from prediction_prophet.functions.parallelism import par_map
 from pydantic.types import SecretStr
@@ -358,7 +359,7 @@ def fields_dict_to_bullet_list(fields_dict: Dict[str, str]) -> str:
     return bullet_list
 
 @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True)
-@db_cache
+@db_cache(max_age=timedelta(days=1))
 def search_google(query: str, num: int = 3) -> List[str]:
     """Search Google using a custom search engine."""
     service = build("customsearch", "v1", developerKey=os.getenv("GOOGLE_SEARCH_API_KEY"))
@@ -1220,7 +1221,7 @@ def make_prediction(
     api_key: SecretStr | None = None,
 ) -> Prediction:
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     current_time_utc = datetime.now(timezone.utc)
     formatted_time_utc = current_time_utc.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-6] + "Z"
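A note on the decorator change above: db_cache on search_google now takes max_age=timedelta(days=1), which reads as a freshness bound on cached results. The sketch below is a minimal, hypothetical illustration of that behaviour only; it is not the implementation of prediction_market_agent_tooling's db_cache, whose storage and key semantics are not visible in this diff.

import functools
import time
from datetime import timedelta
from typing import Any, Callable


def max_age_cache(max_age: timedelta) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Hypothetical in-memory stand-in for a max_age-bounded cache decorator.
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        store: dict[tuple[Any, ...], tuple[float, Any]] = {}

        @functools.wraps(fn)
        def wrapper(*args: Any) -> Any:
            now = time.time()
            if args in store:
                stored_at, value = store[args]
                if now - stored_at < max_age.total_seconds():
                    return value  # entry is still fresh, reuse it
            value = fn(*args)  # recompute on a miss or a stale entry
            store[args] = (now, value)
            return value

        return wrapper

    return decorator


@max_age_cache(max_age=timedelta(days=1))
def cached_search(query: str) -> list[str]:
    # Stand-in for search_google: identical queries within a day return the
    # cached list instead of calling the external API again.
    return [f"result for {query}"]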
2 changes: 1 addition & 1 deletion prediction_prophet/benchmark/agents.py
@@ -11,7 +11,6 @@
 from prediction_prophet.autonolas.research import EmbeddingModel
 from prediction_prophet.autonolas.research import make_prediction, get_urls_from_queries
 from prediction_prophet.autonolas.research import research as research_autonolas
-from prediction_market_agent_tooling.tools.is_predictable import is_predictable_binary
 from prediction_prophet.functions.rephrase_question import rephrase_question
 from prediction_prophet.functions.research import Research, research as prophet_research
 from prediction_prophet.functions.search import search
@@ -26,6 +25,7 @@
 from pydantic.types import SecretStr
 from prediction_prophet.autonolas.research import Prediction as LLMCompletionPredictionDict
 from prediction_market_agent_tooling.tools.langfuse_ import observe
+from prediction_market_agent_tooling.tools.is_predictable import is_predictable_binary
 
 if t.TYPE_CHECKING:
     from loguru import Logger
4 changes: 2 additions & 2 deletions
@@ -13,13 +13,13 @@
 from prediction_prophet.models.WebScrapeResult import WebScrapeResult
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from pydantic.types import SecretStr
-from prediction_market_agent_tooling.tools.utils import secret_str_from_env
+from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
 
 
 def create_embeddings_from_results(results: list[WebScrapeResult], text_splitter: RecursiveCharacterTextSplitter, api_key: SecretStr | None = None) -> Chroma:
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     collection = Chroma(embedding_function=OpenAIEmbeddings(api_key=secretstr_to_v1_secretstr(api_key)))
     texts = []
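The api_key handling repeated across these files follows one pattern: use an explicitly passed key if given, otherwise fall back to the key exposed by APIKeys from prediction_market_agent_tooling.config instead of reading OPENAI_API_KEY via secret_str_from_env. A minimal sketch of that fallback, assuming APIKeys().openai_api_key resolves the key from the configured environment (its exact source is not part of this diff); the helper name resolve_openai_api_key is hypothetical.

from pydantic.types import SecretStr
from prediction_market_agent_tooling.config import APIKeys


def resolve_openai_api_key(api_key: SecretStr | None = None) -> SecretStr:
    # Hypothetical helper mirroring the inline "if api_key == None:" blocks
    # changed in this commit.
    if api_key is None:
        api_key = APIKeys().openai_api_key  # central config instead of a direct env lookup
    return api_key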
4 changes: 2 additions & 2 deletions prediction_prophet/functions/debate_prediction.py
@@ -9,7 +9,7 @@
 from langchain.schema.output_parser import StrOutputParser
 from langchain_openai import ChatOpenAI
 from langchain.prompts import ChatPromptTemplate
-from prediction_market_agent_tooling.tools.utils import secret_str_from_env
+from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
 
 
@@ -85,7 +85,7 @@
 
 def make_debated_prediction(prompt: str, additional_information: str, api_key: SecretStr | None = None) -> Prediction:
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     formatted_time_utc = datetime.datetime.now(datetime.timezone.utc).isoformat(timespec='seconds') + "Z"
 
4 changes: 2 additions & 2 deletions prediction_prophet/functions/generate_subqueries.py
@@ -3,7 +3,7 @@
 from langchain.output_parsers import CommaSeparatedListOutputParser
 from langchain.prompts import ChatPromptTemplate
 from pydantic.types import SecretStr
-from prediction_market_agent_tooling.tools.utils import secret_str_from_env
+from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
 from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe
 
@@ -22,7 +22,7 @@ def generate_subqueries(query: str, limit: int, model: str, temperature: float,
         return [query]
 
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     subquery_generation_prompt = ChatPromptTemplate.from_template(template=subquery_generation_template)
 
6 changes: 3 additions & 3 deletions prediction_prophet/functions/prepare_report.py
@@ -5,15 +5,15 @@
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema.output_parser import StrOutputParser
 from prediction_prophet.functions.utils import trim_to_n_tokens
-from prediction_market_agent_tooling.tools.utils import secret_str_from_env
+from prediction_market_agent_tooling.config import APIKeys
 from pydantic.types import SecretStr
 from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
 from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe
 
 @observe()
 def prepare_summary(goal: str, content: str, model: str, api_key: SecretStr | None = None, trim_content_to_tokens: t.Optional[int] = None) -> str:
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     prompt_template = """Write comprehensive summary of the following web content, that provides relevant information to answer the question: '{goal}'.
 But cut the fluff and keep it up to the point.
@@ -43,7 +43,7 @@ def prepare_summary(goal: str, content: str, model: str, api_key: SecretStr | No
 @observe()
 def prepare_report(goal: str, scraped: list[str], model: str, temperature: float, api_key: SecretStr | None = None) -> str:
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     evaluation_prompt_template = """
 You are a professional researcher. Your goal is to provide a relevant information report
4 changes: 2 additions & 2 deletions prediction_prophet/functions/rerank_subqueries.py
@@ -3,7 +3,7 @@
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema.output_parser import StrOutputParser
 from pydantic.types import SecretStr
-from prediction_market_agent_tooling.tools.utils import secret_str_from_env
+from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
 from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe
 
@@ -20,7 +20,7 @@
 @observe()
 def rerank_subqueries(queries: list[str], goal: str, model: str, temperature: float, api_key: SecretStr | None = None) -> list[str]:
     if api_key == None:
-        api_key = secret_str_from_env("OPENAI_API_KEY")
+        api_key = APIKeys().openai_api_key
 
     rerank_results_prompt = ChatPromptTemplate.from_template(template=rerank_queries_template)
 
3 changes: 2 additions & 1 deletion prediction_prophet/functions/web_scrape.py
@@ -4,18 +4,19 @@
 from bs4 import BeautifulSoup
 from requests import Response
 import tenacity
+from datetime import timedelta
 from prediction_market_agent_tooling.tools.caches.db_cache import db_cache
 
 
 @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True)
-@db_cache
 def fetch_html(url: str, timeout: int) -> Response:
     headers = {
         "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:107.0) Gecko/20100101 Firefox/107.0"
     }
     response = requests.get(url, headers=headers, timeout=timeout)
     return response
 
+@db_cache(max_age=timedelta(days=1))
 def web_scrape_strict(url: str, timeout: int = 10) -> str:
     response = fetch_html(url=url, timeout=timeout)
 
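In web_scrape.py the bare db_cache moves off fetch_html, and a day-bounded db_cache(max_age=timedelta(days=1)) is added to web_scrape_strict, so after this commit the cached value is the extracted text returned by web_scrape_strict rather than the raw Response, while fetch_html keeps only the tenacity retry. A hedged usage sketch, assuming db_cache keys on the call arguments; the URL is a placeholder.

from prediction_prophet.functions.web_scrape import web_scrape_strict

# First call fetches and scrapes the page; a repeat call with the same
# arguments within a day is expected to be served from the cache
# (exact db_cache behaviour is assumed, not shown in this diff).
text_first = web_scrape_strict("https://example.com", timeout=10)
text_again = web_scrape_strict("https://example.com", timeout=10)
assert text_first == text_again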
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ scikit-learn = "^1.4.0"
 typer = ">=0.9.0,<1.0.0"
 types-requests = "^2.31.0.20240125"
 types-python-dateutil = "^2.9.0"
-prediction-market-agent-tooling = { version = "^0.56.0.dev112", extras = ["langchain", "google"] }
+prediction-market-agent-tooling = { version = "^0.56.0.dev113", extras = ["langchain", "google"] }
 langchain-community = "^0.2.6"
 memory-profiler = "^0.61.0"
 matplotlib = "^3.8.3"
