diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py
index 27e2ec3..e862724 100644
--- a/knowledge_storm/__init__.py
+++ b/knowledge_storm/__init__.py
@@ -1,7 +1,7 @@
 from .storm_wiki.engine import (
     STORMWikiLMConfigs,
     STORMWikiRunnerArguments,
-    STORMWikiRunner
+    STORMWikiRunner,
 )

 __version__ = "0.2.5"
diff --git a/knowledge_storm/interface.py b/knowledge_storm/interface.py
index 03df2fb..f6c11bd 100644
--- a/knowledge_storm/interface.py
+++ b/knowledge_storm/interface.py
@@ -5,7 +5,9 @@
 from collections import OrderedDict
 from typing import Dict, List, Optional, Union

-logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s')
+logging.basicConfig(
+    level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
+)
 logger = logging.getLogger(__name__)
@@ -70,7 +72,9 @@ class Article(ABC):
     def __init__(self, topic_name):
         self.root = ArticleSectionNode(topic_name)

-    def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]:
+    def find_section(
+        self, node: ArticleSectionNode, name: str
+    ) -> Optional[ArticleSectionNode]:
         """
         Return the node of the section given the section name.
@@ -152,7 +156,9 @@ def prune_empty_nodes(self, node=None):
         if node is None:
             node = self.root

-        node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)]
+        node.children[:] = [
+            child for child in node.children if self.prune_empty_nodes(child)
+        ]
         if (node.content is None or node.content == "") and not node.children:
             return None
@@ -178,7 +184,9 @@ def update_search_top_k(self, k):
     def collect_and_reset_rm_usage(self):
         combined_usage = []
         for attr_name in self.__dict__:
-            if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
+            if "_rm" in attr_name and hasattr(
+                getattr(self, attr_name), "get_usage_and_reset"
+            ):
                 combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

         name_to_usage = {}
@@ -240,7 +248,9 @@ class OutlineGenerationModule(ABC):
     """

     @abstractmethod
-    def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article:
+    def generate_outline(
+        self, topic: str, information_table: InformationTable, **kwargs
+    ) -> Article:
         """
         Generate outline for the article. Required arguments include:
             topic: the topic of interest
@@ -263,11 +273,13 @@ class ArticleGenerationModule(ABC):
     """

     @abstractmethod
-    def generate_article(self,
-                         topic: str,
-                         information_table: InformationTable,
-                         article_with_outline: Article,
-                         **kwargs) -> Article:
+    def generate_article(
+        self,
+        topic: str,
+        information_table: InformationTable,
+        article_with_outline: Article,
+        **kwargs,
+    ) -> Article:
         """
         Generate article. Required arguments include:
             topic: the topic of interest
@@ -312,14 +324,15 @@ def wrapper(self, *args, **kwargs):

 class LMConfigs(ABC):
     """Abstract base class for language model configurations of the knowledge curation engine.

-    The language model used for each part should be declared with a suffix '_lm' in the attribute name."""
+    The language model used for each part should be declared with a suffix '_lm' in the attribute name.
+    """

     def __init__(self):
         pass

     def init_check(self):
         for attr_name in self.__dict__:
-            if '_lm' in attr_name and getattr(self, attr_name) is None:
+            if "_lm" in attr_name and getattr(self, attr_name) is None:
                 logging.warning(
                     f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()"
                 )
@@ -327,7 +340,7 @@ def init_check(self):
     def collect_and_reset_lm_history(self):
         history = []
         for attr_name in self.__dict__:
-            if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'):
+            if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"):
                 history.extend(getattr(self, attr_name).history)
                 getattr(self, attr_name).history = []
@@ -336,7 +349,9 @@ def collect_and_reset_lm_history(self):
     def collect_and_reset_lm_usage(self):
         combined_usage = []
         for attr_name in self.__dict__:
-            if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
+            if "_lm" in attr_name and hasattr(
+                getattr(self, attr_name), "get_usage_and_reset"
+            ):
                 combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

         model_name_to_usage = {}
@@ -345,8 +360,12 @@ def collect_and_reset_lm_usage(self):
             if model_name not in model_name_to_usage:
                 model_name_to_usage[model_name] = tokens
             else:
-                model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens']
-                model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens']
+                model_name_to_usage[model_name]["prompt_tokens"] += tokens[
+                    "prompt_tokens"
+                ]
+                model_name_to_usage[model_name]["completion_tokens"] += tokens[
+                    "completion_tokens"
+                ]

         return model_name_to_usage
@@ -354,8 +373,9 @@ def log(self):
         return OrderedDict(
             {
-                attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if
-                '_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs')
+                attr_name: getattr(self, attr_name).kwargs
+                for attr_name in self.__dict__
+                if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs")
             }
         )
@@ -379,16 +399,21 @@ def wrapper(*args, **kwargs):
             self.time[func.__name__] = execution_time
             logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds")
             self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()
-            if hasattr(self, 'retriever'):
-                self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage()
+            if hasattr(self, "retriever"):
+                self.rm_cost[func.__name__] = (
+                    self.retriever.collect_and_reset_rm_usage()
+                )
             return result

         return wrapper

     def apply_decorators(self):
         """Apply decorators to methods that need them."""
-        methods_to_decorate = [method_name for method_name in dir(self)
-                               if callable(getattr(self, method_name)) and method_name.startswith('run_')]
+        methods_to_decorate = [
+            method_name
+            for method_name in dir(self)
+            if callable(getattr(self, method_name)) and method_name.startswith("run_")
+        ]
         for method_name in methods_to_decorate:
             original_method = getattr(self, method_name)
             decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)
diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py
index 520d530..2c0773d 100644
--- a/knowledge_storm/lm.py
+++ b/knowledge_storm/lm.py
@@ -23,11 +23,11 @@ class OpenAIModel(dspy.OpenAI):
     """A wrapper class for dspy.OpenAI."""

     def __init__(
-            self,
-            model: str = "gpt-3.5-turbo-instruct",
-            api_key: Optional[str] = None,
-            model_type: Literal["chat", "text"] = None,
-            **kwargs
+        self,
+        model: str = "gpt-3.5-turbo-instruct",
+        api_key: Optional[str] = None,
+        model_type: Literal["chat", "text"] = None,
+        **kwargs,
     ):
         super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs)
         self._token_usage_lock = threading.Lock()
@@ -36,17 +36,20 @@ def __init__(

     def log_usage(self, response):
         """Log the total tokens from the OpenAI API response."""
-        usage_data = response.get('usage')
+        usage_data = response.get("usage")
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.get('prompt_tokens', 0)
-                self.completion_tokens += usage_data.get('completion_tokens', 0)
+                self.prompt_tokens += usage_data.get("prompt_tokens", 0)
+                self.completion_tokens += usage_data.get("completion_tokens", 0)

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.kwargs.get('model') or self.kwargs.get('engine'):
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.kwargs.get("model")
+            or self.kwargs.get("engine"): {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -54,11 +57,11 @@ def get_usage_and_reset(self):
         return usage

     def __call__(
-            self,
-            prompt: str,
-            only_completed: bool = True,
-            return_sorted: bool = False,
-            **kwargs,
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
     ) -> list[dict[str, Any]]:
         """Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage."""
@@ -110,11 +113,11 @@ class DeepSeekModel(dspy.OpenAI):
     """A wrapper class for DeepSeek API, compatible with dspy.OpenAI."""

     def __init__(
-            self,
-            model: str = "deepseek-chat",
-            api_key: Optional[str] = None,
-            api_base: str = "https://api.deepseek.com",
-            **kwargs
+        self,
+        model: str = "deepseek-chat",
+        api_key: Optional[str] = None,
+        api_base: str = "https://api.deepseek.com",
+        **kwargs,
     ):
         super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs)
         self._token_usage_lock = threading.Lock()
@@ -125,21 +128,24 @@ def __init__(
         self.api_base = api_base
         if not self.api_key:
             raise ValueError(
-                "DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY")
+                "DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY"
+            )

     def log_usage(self, response):
         """Log the total tokens from the DeepSeek API response."""
-        usage_data = response.get('usage')
+        usage_data = response.get("usage")
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.get('prompt_tokens', 0)
-                self.completion_tokens += usage_data.get('completion_tokens', 0)
+                self.prompt_tokens += usage_data.get("prompt_tokens", 0)
+                self.completion_tokens += usage_data.get("completion_tokens", 0)

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.model:
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.model: {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -156,23 +162,25 @@ def _create_completion(self, prompt: str, **kwargs):
         """Create a completion using the DeepSeek API."""
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}"
+            "Authorization": f"Bearer {self.api_key}",
         }
         data = {
             "model": self.model,
             "messages": [{"role": "user", "content": prompt}],
-            **kwargs
+            **kwargs,
         }
-        response = requests.post(f"{self.api_base}/v1/chat/completions", headers=headers, json=data)
+        response = requests.post(
+            f"{self.api_base}/v1/chat/completions", headers=headers, json=data
+        )
         response.raise_for_status()
         return response.json()

     def __call__(
-            self,
-            prompt: str,
-            only_completed: bool = True,
-            return_sorted: bool = False,
-            **kwargs,
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
     ) -> list[dict[str, Any]]:
         """Call the DeepSeek API to generate completions."""
         assert only_completed, "for now"
@@ -200,34 +208,44 @@ class AzureOpenAIModel(dspy.AzureOpenAI):
     """A wrapper class for dspy.AzureOpenAI."""

     def __init__(
-            self,
-            api_base: Optional[str] = None,
-            api_version: Optional[str] = None,
-            model: str = "gpt-3.5-turbo-instruct",
-            api_key: Optional[str] = None,
-            model_type: Literal["chat", "text"] = "chat",
-            **kwargs,
+        self,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        model: str = "gpt-3.5-turbo-instruct",
+        api_key: Optional[str] = None,
+        model_type: Literal["chat", "text"] = "chat",
+        **kwargs,
     ):
         super().__init__(
-            api_base=api_base, api_version=api_version, model=model, api_key=api_key, model_type=model_type, **kwargs)
+            api_base=api_base,
+            api_version=api_version,
+            model=model,
+            api_key=api_key,
+            model_type=model_type,
+            **kwargs,
+        )
         self._token_usage_lock = threading.Lock()
         self.prompt_tokens = 0
         self.completion_tokens = 0

     def log_usage(self, response):
         """Log the total tokens from the OpenAI API response.
-        Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage."""
-        usage_data = response.get('usage')
+        Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage.
+        """
+        usage_data = response.get("usage")
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.get('prompt_tokens', 0)
-                self.completion_tokens += usage_data.get('completion_tokens', 0)
+                self.prompt_tokens += usage_data.get("prompt_tokens", 0)
+                self.completion_tokens += usage_data.get("completion_tokens", 0)

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.kwargs.get('model') or self.kwargs.get('engine'):
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.kwargs.get("model")
+            or self.kwargs.get("engine"): {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -239,11 +257,11 @@ class GroqModel(dspy.OpenAI):
     """A wrapper class for Groq API (https://console.groq.com/), compatible with dspy.OpenAI."""

     def __init__(
-            self,
-            model: str = "llama3-70b-8192",
-            api_key: Optional[str] = None,
-            api_base: str = "https://api.groq.com/openai/v1",
-            **kwargs
+        self,
+        model: str = "llama3-70b-8192",
+        api_key: Optional[str] = None,
+        api_base: str = "https://api.groq.com/openai/v1",
+        **kwargs,
     ):
         super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs)
         self._token_usage_lock = threading.Lock()
@@ -254,21 +272,24 @@ def __init__(
         self.api_base = api_base
         if not self.api_key:
             raise ValueError(
-                "Groq API key must be provided either as an argument or as an environment variable GROQ_API_KEY")
+                "Groq API key must be provided either as an argument or as an environment variable GROQ_API_KEY"
+            )

     def log_usage(self, response):
         """Log the total tokens from the Groq API response."""
-        usage_data = response.get('usage')
+        usage_data = response.get("usage")
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.get('prompt_tokens', 0)
-                self.completion_tokens += usage_data.get('completion_tokens', 0)
+                self.prompt_tokens += usage_data.get("prompt_tokens", 0)
+                self.completion_tokens += usage_data.get("completion_tokens", 0)

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.model:
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.model: {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -285,42 +306,44 @@ def _create_completion(self, prompt: str, **kwargs):
         """Create a completion using the Groq API."""
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {self.api_key}"
+            "Authorization": f"Bearer {self.api_key}",
         }

         # Remove unsupported fields
-        kwargs.pop('logprobs', None)
-        kwargs.pop('logit_bias', None)
-        kwargs.pop('top_logprobs', None)
+        kwargs.pop("logprobs", None)
+        kwargs.pop("logit_bias", None)
+        kwargs.pop("top_logprobs", None)

         # Ensure N is 1 if supplied
-        if 'n' in kwargs and kwargs['n'] != 1:
+        if "n" in kwargs and kwargs["n"] != 1:
             raise ValueError("Groq API only supports N=1")

         # Adjust temperature if it's 0
-        if kwargs.get('temperature', 1) == 0:
-            kwargs['temperature'] = 1e-8
+        if kwargs.get("temperature", 1) == 0:
+            kwargs["temperature"] = 1e-8

         data = {
             "model": self.model,
             "messages": [{"role": "user", "content": prompt}],
-            **kwargs
+            **kwargs,
         }

         # Remove 'name' field from messages if present
-        for message in data['messages']:
-            message.pop('name', None)
+        for message in data["messages"]:
+            message.pop("name", None)

-        response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=data)
+        response = requests.post(
+            f"{self.api_base}/chat/completions", headers=headers, json=data
+        )
         response.raise_for_status()
         return response.json()

     def __call__(
-            self,
-            prompt: str,
-            only_completed: bool = True,
-            return_sorted: bool = False,
-            **kwargs,
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
     ) -> list[dict[str, Any]]:
         """Call the Groq API to generate completions."""
         assert only_completed, "for now"
@@ -348,11 +371,11 @@ class ClaudeModel(dspy.dsp.modules.lm.LM):
     """Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage."""

     def __init__(
-            self,
-            model: str,
-            api_key: Optional[str] = None,
-            api_base: Optional[str] = None,
-            **kwargs,
+        self,
+        model: str,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **kwargs,
     ):
         super().__init__(model)
         try:
@@ -361,12 +384,21 @@ def __init__(
             raise ImportError("Claude requires `pip install anthropic`.") from err

         self.provider = "anthropic"
-        self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
-        self.api_base = "https://api.anthropic.com/v1/messages" if api_base is None else api_base
-        self.kwargs = {"temperature": kwargs.get("temperature", 0.0),
-                       "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), "top_p": kwargs.get("top_p", 1.0),
-                       "top_k": kwargs.get("top_k", 1), "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
-                       **kwargs, "model": model}
+        self.api_key = api_key = (
+            os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
+        )
+        self.api_base = (
+            "https://api.anthropic.com/v1/messages" if api_base is None else api_base
+        )
+        self.kwargs = {
+            "temperature": kwargs.get("temperature", 0.0),
+            "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
+            "top_p": kwargs.get("top_p", 1.0),
+            "top_k": kwargs.get("top_k", 1),
+            "n": kwargs.pop("n", kwargs.pop("num_generations", 1)),
+            **kwargs,
+            "model": model,
+        }
         self.history: list[dict[str, Any]] = []
         self.client = Anthropic(api_key=api_key)
         self.model = model
@@ -386,8 +418,10 @@ def log_usage(self, response):
     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.model:
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.model: {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -419,7 +453,7 @@ def basic_request(self, prompt: str, **kwargs):
                 "usage": {
                     "input_tokens": response.usage.input_tokens,
                     "output_tokens": response.usage.output_tokens,
-                }
+                },
             },
             "kwargs": kwargs,
             "raw_kwargs": raw_kwargs,
@@ -475,8 +509,15 @@ class VLLMClient(dspy.dsp.LM):
     vLLM HTTP server is designed to be compatible with the OpenAI API. Use OpenAI client to interact with the server.
     """

-    def __init__(self, model, port, model_type: Literal["chat", "text"] = "text", url="http://localhost",
-                 api_key="null", **kwargs):
+    def __init__(
+        self,
+        model,
+        port,
+        model_type: Literal["chat", "text"] = "text",
+        url="http://localhost",
+        api_key="null",
+        **kwargs,
+    ):
         """Check out https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html for more information."""
         super().__init__(model=model)
         # Store additional kwargs for the generate method.
@@ -517,8 +558,11 @@ def log_usage(self, response):

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.kwargs.get('model') or self.kwargs.get('engine'):
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.kwargs.get("model")
+            or self.kwargs.get("engine"): {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -564,7 +608,13 @@ def __init__(self, model, port, url="http://localhost", **kwargs):

 class TGIClient(dspy.HFClientTGI):
     def __init__(self, model, port, url, http_request_kwargs=None, **kwargs):
-        super().__init__(model=model, port=port, url=url, http_request_kwargs=http_request_kwargs, **kwargs)
+        super().__init__(
+            model=model,
+            port=port,
+            url=url,
+            http_request_kwargs=http_request_kwargs,
+            **kwargs,
+        )

     def _generate(self, prompt, **kwargs):
         """Copied from dspy/dsp/modules/hf_client.py with the addition of removing hard-coded parameters."""
@@ -603,8 +653,8 @@ def _generate(self, prompt, **kwargs):
                 completions = [json_response["generated_text"]]

                 if (
-                        "details" in json_response
-                        and "best_of_sequences" in json_response["details"]
+                    "details" in json_response
+                    and "best_of_sequences" in json_response["details"]
                 ):
                     completions += [
                         x["generated_text"]
@@ -621,13 +671,22 @@ def _generate(self, prompt, **kwargs):
 class TogetherClient(dspy.HFModel):
     """A wrapper class for dspy.Together."""

-    def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name=None, **kwargs):
+    def __init__(
+        self,
+        model,
+        apply_tokenizer_chat_template=False,
+        hf_tokenizer_name=None,
+        **kwargs,
+    ):
         """Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template."""
         super().__init__(model=model, is_client=True)
         self.session = requests.Session()
-        self.api_base = "https://api.together.xyz/v1/completions" if os.getenv(
-            "TOGETHER_API_BASE") is None else os.getenv("TOGETHER_API_BASE")
+        self.api_base = (
+            "https://api.together.xyz/v1/completions"
+            if os.getenv("TOGETHER_API_BASE") is None
+            else os.getenv("TOGETHER_API_BASE")
+        )
         self.token = os.getenv("TOGETHER_API_KEY")
         self.model = model

@@ -639,7 +698,9 @@ def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name
             logging.info("Loading huggingface tokenizer.")
             if hf_tokenizer_name is None:
                 hf_tokenizer_name = self.model
-            self.tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None))
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None)
+            )

         stop_default = "\n\n---"

@@ -659,17 +720,19 @@ def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name

     def log_usage(self, response):
         """Log the total tokens from the OpenAI API response."""
-        usage_data = response.get('usage')
+        usage_data = response.get("usage")
         if usage_data:
             with self._token_usage_lock:
-                self.prompt_tokens += usage_data.get('prompt_tokens', 0)
-                self.completion_tokens += usage_data.get('completion_tokens', 0)
+                self.prompt_tokens += usage_data.get("prompt_tokens", 0)
+                self.completion_tokens += usage_data.get("completion_tokens", 0)

     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.model:
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.model: {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -694,14 +757,18 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
         top_k = kwargs.get("top_k", 50)
         repetition_penalty = kwargs.get("repetition_penalty", 1)
         if self.apply_tokenizer_chat_template:
-            prompt = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False)
+            prompt = self.tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}], tokenize=False
+            )
         # prompt = f"[INST]{prompt}[/INST]" if self.use_inst_template else prompt

         if use_chat_api:
             url = f"{self.api_base}/chat/completions"
             messages = [
-                {"role": "system",
-                 "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."},
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections.",
+                },
                 {"role": "user", "content": prompt},
             ]
             body = {
@@ -734,10 +801,14 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
         self.log_usage(resp_json)
         if use_chat_api:
             # completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")]
-            completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
+            completions = [
+                resp_json.get("choices", [])[0]
+                .get("message", {})
+                .get("content", "")
+            ]
         else:
             # completions = [resp_json['output'].get('choices', [])[0].get('text', "")]
-            completions = [resp_json.get('choices', [])[0].get('text', "")]
+            completions = [resp_json.get("choices", [])[0].get("text", "")]
         response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
         return response

@@ -746,24 +817,28 @@ class GoogleModel(dspy.dsp.modules.lm.LM):
     """A wrapper class for Google Gemini API."""

     def __init__(
-            self,
-            model: str,
-            api_key: Optional[str] = None,
-            **kwargs,
+        self,
+        model: str,
+        api_key: Optional[str] = None,
+        **kwargs,
     ):
         """You can use `genai.list_models()` to get a list of available models."""
         super().__init__(model)
         try:
             import google.generativeai as genai
         except ImportError as err:
-            raise ImportError("GoogleModel requires `pip install google-generativeai`.") from err
+            raise ImportError(
+                "GoogleModel requires `pip install google-generativeai`."
+            ) from err

         api_key = os.environ.get("GOOGLE_API_KEY") if api_key is None else api_key
         genai.configure(api_key=api_key)

         kwargs = {
             "candidate_count": 1,  # Caveat: Gemini API supports only one candidate for now.
-            "temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"],
+            "temperature": (
+                0.0 if "temperature" not in kwargs else kwargs["temperature"]
+            ),
             "max_output_tokens": kwargs["max_tokens"],
             "top_p": 1,
             "top_k": 1,
@@ -774,8 +849,9 @@ def __init__(

         self.model = model
         self.config = genai.GenerationConfig(**kwargs)
-        self.llm = genai.GenerativeModel(model_name=model,
-                                         generation_config=self.config)
+        self.llm = genai.GenerativeModel(
+            model_name=model, generation_config=self.config
+        )

         self.kwargs = {
             "n": 1,
@@ -799,8 +875,10 @@ def log_usage(self, response):
     def get_usage_and_reset(self):
         """Get the total tokens used and reset the token usage."""
         usage = {
-            self.model:
-                {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
+            self.model: {
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+            }
         }
         self.prompt_tokens = 0
         self.completion_tokens = 0
@@ -842,11 +920,11 @@ def request(self, prompt: str, **kwargs):
         return self.basic_request(prompt, **kwargs)

     def __call__(
-            self,
-            prompt: str,
-            only_completed: bool = True,
-            return_sorted: bool = False,
-            **kwargs,
+        self,
+        prompt: str,
+        only_completed: bool = True,
+        return_sorted: bool = False,
+        **kwargs,
     ):
         assert only_completed, "for now"
         assert return_sorted is False, "for now"
diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py
index c84c42e..7d4d696 100644
--- a/knowledge_storm/rm.py
+++ b/knowledge_storm/rm.py
@@ -15,7 +15,9 @@ class YouRM(dspy.Retrieve):
     def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
         super().__init__(k=k)
         if not ydc_api_key and not os.environ.get("YDC_API_KEY"):
-            raise RuntimeError("You must supply ydc_api_key or set environment variable YDC_API_KEY")
+            raise RuntimeError(
+                "You must supply ydc_api_key or set environment variable YDC_API_KEY"
+            )
         elif ydc_api_key:
             self.ydc_api_key = ydc_api_key
         else:
@@ -32,9 +34,11 @@ def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0

-        return {'YouRM': usage}
+        return {"YouRM": usage}

-    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
+    def forward(
+        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
+    ):
         """Search with You.com for self.k top passages for query or queries

         Args:
@@ -60,21 +64,30 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
             ).json()

             authoritative_results = []
-            for r in results['hits']:
-                if self.is_valid_source(r['url']) and r['url'] not in exclude_urls:
+            for r in results["hits"]:
+                if self.is_valid_source(r["url"]) and r["url"] not in exclude_urls:
                     authoritative_results.append(r)
-            if 'hits' in results:
-                collected_results.extend(authoritative_results[:self.k])
+            if "hits" in results:
+                collected_results.extend(authoritative_results[: self.k])
         except Exception as e:
-            logging.error(f'Error occurs when searching query {query}: {e}')
+            logging.error(f"Error occurs when searching query {query}: {e}")

         return collected_results


 class BingSearch(dspy.Retrieve):
-    def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
-                 min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
-                 mkt='en-US', language='en', **kwargs):
+    def __init__(
+        self,
+        bing_search_api_key=None,
+        k=3,
+        is_valid_source: Callable = None,
+        min_char_count: int = 150,
+        snippet_chunk_size: int = 1000,
+        webpage_helper_max_threads=10,
+        mkt="en-US",
+        language="en",
+        **kwargs,
+    ):
         """
         Params:
             min_char_count: Minimum character count for the article to be considered valid.
@@ -86,22 +99,18 @@ def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = No
         super().__init__(k=k)
         if not bing_search_api_key and not os.environ.get("BING_SEARCH_API_KEY"):
             raise RuntimeError(
-                "You must supply bing_search_subscription_key or set environment variable BING_SEARCH_API_KEY")
+                "You must supply bing_search_subscription_key or set environment variable BING_SEARCH_API_KEY"
+            )
         elif bing_search_api_key:
             self.bing_api_key = bing_search_api_key
         else:
             self.bing_api_key = os.environ["BING_SEARCH_API_KEY"]
         self.endpoint = "https://api.bing.microsoft.com/v7.0/search"
-        self.params = {
-            'mkt': mkt,
-            "setLang": language,
-            "count": k,
-            **kwargs
-        }
+        self.params = {"mkt": mkt, "setLang": language, "count": k, **kwargs}
         self.webpage_helper = WebPageHelper(
             min_char_count=min_char_count,
             snippet_chunk_size=snippet_chunk_size,
-            max_thread_num=webpage_helper_max_threads
+            max_thread_num=webpage_helper_max_threads,
         )
         self.usage = 0

@@ -115,9 +124,11 @@ def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0

-        return {'BingSearch': usage}
+        return {"BingSearch": usage}

-    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
+    def forward(
+        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
+    ):
         """Search with Bing for self.k top passages for query or queries

         Args:
@@ -141,22 +152,26 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         for query in queries:
             try:
                 results = requests.get(
-                    self.endpoint,
-                    headers=headers,
-                    params={**self.params, 'q': query}
+                    self.endpoint, headers=headers, params={**self.params, "q": query}
                 ).json()

-                for d in results['webPages']['value']:
-                    if self.is_valid_source(d['url']) and d['url'] not in exclude_urls:
-                        url_to_results[d['url']] = {'url': d['url'], 'title': d['name'], 'description': d['snippet']}
+                for d in results["webPages"]["value"]:
+                    if self.is_valid_source(d["url"]) and d["url"] not in exclude_urls:
+                        url_to_results[d["url"]] = {
+                            "url": d["url"],
+                            "title": d["name"],
+                            "description": d["snippet"],
+                        }
             except Exception as e:
-                logging.error(f'Error occurs when searching query {query}: {e}')
+                logging.error(f"Error occurs when searching query {query}: {e}")

-        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys()))
+        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(
+            list(url_to_results.keys())
+        )
         collected_results = []
         for url in valid_url_to_snippets:
             r = url_to_results[url]
-            r['snippets'] = valid_url_to_snippets[url]['snippets']
+            r["snippets"] = valid_url_to_snippets[url]["snippets"]
             collected_results.append(r)

         return collected_results
@@ -174,12 +189,13 @@ class VectorRM(dspy.Retrieve):
     The documents should be stored in a CSV file.
     """

-    def __init__(self,
-                 collection_name: str,
-                 embedding_model: str,
-                 device: str = "mps",
-                 k: int = 3,
-                 ):
+    def __init__(
+        self,
+        collection_name: str,
+        embedding_model: str,
+        device: str = "mps",
+        k: int = 3,
+    ):
         """
         Params:
             collection_name: Name of the Qdrant collection.
@@ -199,7 +215,9 @@ def __init__(self,
         model_kwargs = {"device": device}
         encode_kwargs = {"normalize_embeddings": True}
         self.model = HuggingFaceEmbeddings(
-            model_name=embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
+            model_name=embedding_model,
+            model_kwargs=model_kwargs,
+            encode_kwargs=encode_kwargs,
         )

         self.collection_name = collection_name
@@ -213,14 +231,18 @@ def _check_collection(self):
         if self.client is None:
             raise ValueError("Qdrant client is not initialized.")
         if self.client.collection_exists(collection_name=f"{self.collection_name}"):
-            print(f"Collection {self.collection_name} exists. Loading the collection...")
+            print(
+                f"Collection {self.collection_name} exists. Loading the collection..."
+            )
             self.qdrant = Qdrant(
                 client=self.client,
                 collection_name=self.collection_name,
                 embeddings=self.model,
             )
         else:
-            raise ValueError(f"Collection {self.collection_name} does not exist. Please create the collection first.")
+            raise ValueError(
+                f"Collection {self.collection_name} does not exist. Please create the collection first."
+            )

     def init_online_vector_db(self, url: str, api_key: str):
         """
@@ -263,7 +285,7 @@ def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0

-        return {'VectorRM': usage}
+        return {"VectorRM": usage}

     def get_vector_count(self):
         """
@@ -296,12 +318,14 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
             related_docs = self.qdrant.similarity_search_with_score(query, k=self.k)
             for i in range(len(related_docs)):
                 doc = related_docs[i][0]
-                collected_results.append({
-                    'description': doc.metadata['description'],
-                    'snippets': [doc.page_content],
-                    'title': doc.metadata['title'],
-                    'url': doc.metadata['url'],
-                })
+                collected_results.append(
+                    {
+                        "description": doc.metadata["description"],
+                        "snippets": [doc.page_content],
+                        "title": doc.metadata["title"],
+                        "url": doc.metadata["url"],
+                    }
+                )

         return collected_results

@@ -311,55 +335,55 @@ class SerperRM(dspy.Retrieve):

     def __init__(self, serper_search_api_key=None, query_params=None):
         """Args:
-        serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
-        query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.
-            Commonly used fields are as follows (see more information in https://serper.dev/playground):
-                q str: query that will be used with google search
-                type str: type that will be used for browsing google. Types are search, images, video, maps, places, etc.
-                gl str: Country that will be focused on for the search
-                location str: Country where the search will originate from. All locates can be found here: https://api.serper.dev/locations.
-                autocorrect bool: Enable autocorrect on the queries while searching, if query is misspelled, will be updated.
-                results int: Max number of results per page.
-                page int: Max number of pages per call.
-                tbs str: date time range, automatically set to any time by default.
-                qdr:h str: Date time range for the past hour.
-                qdr:d str: Date time range for the past 24 hours.
-                qdr:w str: Date time range for past week.
-                qdr:m str: Date time range for past month.
-                qdr:y str: Date time range for past year.
+        serper_search_api_key str: API key to run serper, can be found by creating an account on https://serper.dev/
+        query_params (dict or list of dict): parameters in dictionary or list of dictionaries that has a max size of 100 that will be used to query.
+            Commonly used fields are as follows (see more information in https://serper.dev/playground):
+                q str: query that will be used with google search
+                type str: type that will be used for browsing google. Types are search, images, video, maps, places, etc.
+                gl str: Country that will be focused on for the search
+                location str: Country where the search will originate from. All locates can be found here: https://api.serper.dev/locations.
+                autocorrect bool: Enable autocorrect on the queries while searching, if query is misspelled, will be updated.
+                results int: Max number of results per page.
+                page int: Max number of pages per call.
+                tbs str: date time range, automatically set to any time by default.
+                qdr:h str: Date time range for the past hour.
+                qdr:d str: Date time range for the past 24 hours.
+                qdr:w str: Date time range for past week.
+                qdr:m str: Date time range for past month.
+                qdr:y str: Date time range for past year.
         """
         super().__init__()
         self.usage = 0
         self.query_params = query_params
         self.serper_search_api_key = serper_search_api_key
-        if not self.serper_search_api_key and not os.environ.get('SERPER_API_KEY'):
+        if not self.serper_search_api_key and not os.environ.get("SERPER_API_KEY"):
             raise RuntimeError(
-                'You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY'
+                "You must supply a serper_search_api_key param or set environment variable SERPER_API_KEY"
             )
         elif self.serper_search_api_key:
             self.serper_search_api_key = serper_search_api_key
         else:
-            self.serper_search_api_key = os.environ['SERPER_API_KEY']
+            self.serper_search_api_key = os.environ["SERPER_API_KEY"]

-        self.base_url = 'https://google.serper.dev'
+        self.base_url = "https://google.serper.dev"

     def serper_runner(self, query_params):
-        self.search_url = f'{self.base_url}/search'
+        self.search_url = f"{self.base_url}/search"

         headers = {
-            'X-API-KEY': self.serper_search_api_key,
-            'Content-Type': 'application/json',
+            "X-API-KEY": self.serper_search_api_key,
+            "Content-Type": "application/json",
         }

         response = requests.request(
-            'POST', self.search_url, headers=headers, json=query_params
+            "POST", self.search_url, headers=headers, json=query_params
         )

         if response == None:
             raise RuntimeError(
-                f'Error had occured while running the search process.\n Error is {response.reason}, had failed with status code {response.status_code}'
+                f"Error had occured while running the search process.\n Error is {response.reason}, had failed with status code {response.status_code}"
             )

         return response.json()
@@ -367,7 +391,7 @@ def serper_runner(self, query_params):
     def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0
-        return {'SerperRM': usage}
+        return {"SerperRM": usage}

     def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
         """
@@ -391,16 +415,16 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         self.results = []
         collected_results = []
         for query in queries:
-            if query == 'Queries:':
+            if query == "Queries:":
                 continue
             query_params = self.query_params

             # All available parameters can be found in the playground: https://serper.dev/playground
             # Sets the json value for query to be the query that is being parsed.
-            query_params['q'] = query
+            query_params["q"] = query

             # Sets the type to be search, can be images, video, places, maps etc that Google provides.
-            query_params['type'] = 'search'
+            query_params["type"] = "search"

             self.result = self.serper_runner(query_params)
             self.results.append(self.result)
@@ -411,29 +435,29 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         for result in self.results:
             try:
                 # An array of dictionaries that contains the snippets, title of the document and url that will be used.
-                organic_results = result.get('organic')
+                organic_results = result.get("organic")

-                knowledge_graph = result.get('knowledgeGraph')
+                knowledge_graph = result.get("knowledgeGraph")
                 for organic in organic_results:
                     snippets = []
-                    snippets.append(organic.get('snippet'))
+                    snippets.append(organic.get("snippet"))
                     if knowledge_graph != None:
                         collected_results.append(
                             {
-                                'snippets': snippets,
-                                'title': organic.get('title'),
-                                'url': organic.get('link'),
-                                'description': knowledge_graph.get('description'),
+                                "snippets": snippets,
+                                "title": organic.get("title"),
+                                "url": organic.get("link"),
+                                "description": knowledge_graph.get("description"),
                             }
                         )
                     else:
                         # Common for knowledge graph to be None, set description to empty string
                         collected_results.append(
                             {
-                                'snippets': snippets,
-                                'title': organic.get('title'),
-                                'url': organic.get('link'),
-                                'description': '',
+                                "snippets": snippets,
+                                "title": organic.get("title"),
+                                "url": organic.get("link"),
+                                "description": "",
                             }
                         )
             except:
@@ -443,10 +467,14 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st


 class BraveRM(dspy.Retrieve):
-    def __init__(self, brave_search_api_key=None, k=3, is_valid_source: Callable = None):
+    def __init__(
+        self, brave_search_api_key=None, k=3, is_valid_source: Callable = None
+    ):
         super().__init__(k=k)
         if not brave_search_api_key and not os.environ.get("BRAVE_API_KEY"):
-            raise RuntimeError("You must supply brave_search_api_key or set environment variable BRAVE_API_KEY")
+            raise RuntimeError(
+                "You must supply brave_search_api_key or set environment variable BRAVE_API_KEY"
+            )
         elif brave_search_api_key:
             self.brave_search_api_key = brave_search_api_key
         else:
@@ -463,9 +491,11 @@ def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0

-        return {'BraveRM': usage}
+        return {"BraveRM": usage}

-    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
+    def forward(
+        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
+    ):
         """Search with api.search.brave.com for self.k top passages for query or queries

         Args:
@@ -487,30 +517,37 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
                 headers = {
                     "Accept": "application/json",
                     "Accept-Encoding": "gzip",
-                    "X-Subscription-Token": self.brave_search_api_key
+                    "X-Subscription-Token": self.brave_search_api_key,
                 }
                 response = requests.get(
                     f"https://api.search.brave.com/res/v1/web/search?result_filter=web&q={query}",
                     headers=headers,
                 ).json()
-                results = response.get('web', {}).get('results', [])
+                results = response.get("web", {}).get("results", [])

                 for result in results:
                     collected_results.append(
                         {
-                            'snippets': result.get('extra_snippets', []),
-                            'title': result.get('title'),
-                            'url': result.get('url'),
-                            'description': result.get('description'),
+                            "snippets": result.get("extra_snippets", []),
+                            "title": result.get("title"),
+                            "url": result.get("url"),
+                            "description": result.get("description"),
                         }
                     )
             except Exception as e:
-                logging.error(f'Error occurs when searching query {query}: {e}')
+                logging.error(f"Error occurs when searching query {query}: {e}")

         return collected_results

+
 class SearXNG(dspy.Retrieve):
-    def __init__(self, searxng_api_url, searxng_api_key=None, k=3, is_valid_source: Callable = None):
+    def __init__(
+        self,
+        searxng_api_url,
+        searxng_api_key=None,
+        k=3,
+        is_valid_source: Callable = None,
+    ):
         """Initialize the SearXNG search retriever.
         Please set up SearXNG according to https://docs.searxng.org/index.html.
@@ -536,9 +573,11 @@ def __init__(self, searxng_api_url, searxng_api_key=None, k=3, is_valid_source:
     def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0
-        return {'SearXNG': usage}
+        return {"SearXNG": usage}

-    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
+    def forward(
+        self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
+    ):
         """Search with SearxNG for self.k top passages for query or queries

         Args:
@@ -555,38 +594,48 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st
         )
         self.usage += len(queries)
         collected_results = []
-        headers = {"Authorization": f"Bearer {self.searxng_api_key}"} if self.searxng_api_key else {}
+        headers = (
+            {"Authorization": f"Bearer {self.searxng_api_key}"}
+            if self.searxng_api_key
+            else {}
+        )

         for query in queries:
             try:
                 params = {"q": query, "format": "json"}
-                response = requests.get(self.searxng_api_url, headers=headers, params=params)
+                response = requests.get(
+                    self.searxng_api_url, headers=headers, params=params
+                )
                 results = response.json()

-                for r in results['results']:
-                    if self.is_valid_source(r['url']) and r['url'] not in exclude_urls:
-                        collected_results.append({
-                            'description': r.get('content', ''),
-                            'snippets': [r.get('content', '')],
-                            'title': r.get('title', ''),
-                            'url': r['url']
-                        })
+                for r in results["results"]:
+                    if self.is_valid_source(r["url"]) and r["url"] not in exclude_urls:
+                        collected_results.append(
+                            {
+                                "description": r.get("content", ""),
+                                "snippets": [r.get("content", "")],
+                                "title": r.get("title", ""),
+                                "url": r["url"],
+                            }
+                        )
             except Exception as e:
-                logging.error(f'Error occurs when searching query {query}: {e}')
+                logging.error(f"Error occurs when searching query {query}: {e}")

         return collected_results

+
 class DuckDuckGoSearchRM(dspy.Retrieve):
     """Retrieve information from custom queries using DuckDuckGo."""
+
     def __init__(
         self,
-        k: int =3,
+        k: int = 3,
         is_valid_source: Callable = None,
         min_char_count: int = 150,
         snippet_chunk_size: int = 1000,
         webpage_helper_max_threads=10,
-        safe_search: str = 'On',
-        region: str = 'us-en'
+        safe_search: str = "On",
+        region: str = "us-en",
     ):
         """
         Params:
@@ -599,7 +648,9 @@ def __init__(
         try:
             from duckduckgo_search import DDGS
         except ImportError as err:
-            raise ImportError("Duckduckgo requires `pip install duckduckgo_search`.") from err
+            raise ImportError(
+                "Duckduckgo requires `pip install duckduckgo_search`."
+            ) from err
         self.k = k
         self.webpage_helper = WebPageHelper(
             min_char_count=min_char_count,
@@ -607,11 +658,11 @@ def __init__(
             max_thread_num=webpage_helper_max_threads,
         )
         self.usage = 0
-        # All params for search can be found here: 
+        # All params for search can be found here:
         # https://duckduckgo.com/duckduckgo-help-pages/settings/params/

         # Sets the backend to be api
-        self.duck_duck_go_backend = 'api'
+        self.duck_duck_go_backend = "api"

         # Only gets safe search results
         self.duck_duck_go_safe_search = safe_search
@@ -619,7 +670,6 @@ def __init__(
         # Specifies the region that the search will use
         self.duck_duck_go_region = region

-
         # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
         if is_valid_source:
             self.is_valid_source = is_valid_source
@@ -632,7 +682,7 @@ def __init__(
     def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0
-        return {'DuckDuckGoRM': usage}
+        return {"DuckDuckGoRM": usage}

     def forward(
         self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
@@ -662,37 +712,39 @@ def forward(
             for d in results:
                 # assert d is dict
                 if not isinstance(d, dict):
-                    print(f'Invalid result: {d}\n')
+                    print(f"Invalid result: {d}\n")
                    continue

                 try:
                     # ensure keys are present
-                    url = d.get('href', None)
-                    title = d.get('title', None)
-                    description = d.get('description', title)
-                    snippets = [d.get('body', None)]
+                    url = d.get("href", None)
+                    title = d.get("title", None)
+                    description = d.get("description", title)
+                    snippets = [d.get("body", None)]

                     # raise exception of missing key(s)
                     if not all([url, title, description, snippets]):
-                        raise ValueError(f'Missing key(s) in result: {d}')
+                        raise ValueError(f"Missing key(s) in result: {d}")
                     if self.is_valid_source(url) and url not in exclude_urls:
                         result = {
-                            'url': url,
-                            'title': title,
-                            'description': description,
-                            'snippets': snippets,
+                            "url": url,
+                            "title": title,
+                            "description": description,
+                            "snippets": snippets,
                         }
                         collected_results.append(result)
                     else:
-                        print(f'invalid source {url} or url in exclude_urls')
+                        print(f"invalid source {url} or url in exclude_urls")
                 except Exception as e:
-                    print(f'Error occurs when processing {result=}: {e}\n')
-                    print(f'Error occurs when searching query {query}: {e}')
+                    print(f"Error occurs when processing {result=}: {e}\n")
+                    print(f"Error occurs when searching query {query}: {e}")

         return collected_results

+
 class TavilySearchRM(dspy.Retrieve):
     """Retrieve information from custom queries using Tavily. Documentation and examples can be found at https://docs.tavily.com/docs/python-sdk/tavily-search/examples"""
+
     def __init__(
         self,
         tavily_search_api_key=None,
@@ -701,7 +753,7 @@ def __init__(
         min_char_count: int = 150,
         snippet_chunk_size: int = 1000,
         webpage_helper_max_threads=10,
-        include_raw_content = False
+        include_raw_content=False,
     ):
         """
         Params:
@@ -716,9 +768,11 @@ def __init__(
             from tavily import TavilyClient
         except ImportError as err:
             raise ImportError("Tavily requires `pip install tavily-python`.") from err
-
+
         if not tavily_search_api_key and not os.environ.get("TAVILY_API_KEY"):
-            raise RuntimeError("You must supply tavily_search_api_key or set environment variable TAVILY_API_KEY")
+            raise RuntimeError(
+                "You must supply tavily_search_api_key or set environment variable TAVILY_API_KEY"
+            )
         elif tavily_search_api_key:
             self.tavily_search_api_key = tavily_search_api_key
         else:
@@ -733,12 +787,12 @@ def __init__(

         self.usage = 0

-        # Creates client instance that will use search. Full search params are here: 
+        # Creates client instance that will use search. Full search params are here:
         # https://docs.tavily.com/docs/python-sdk/tavily-search/examples
         self.tavily_client = TavilyClient(api_key=self.tavily_search_api_key)

         self.include_raw_content = include_raw_content
-
+
         # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
         if is_valid_source:
             self.is_valid_source = is_valid_source
@@ -748,7 +802,7 @@ def __init__(
     def get_usage_and_reset(self):
         usage = self.usage
         self.usage = 0
-        return {'TavilySearchRM': usage}
+        return {"TavilySearchRM": usage}

     def forward(
         self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
@@ -772,7 +826,7 @@ def forward(
         for query in queries:
             args = {
                 "max_results": self.k,
-                "include_raw_contents": self.include_raw_content
+                "include_raw_contents": self.include_raw_content,
             }

             # list of dicts that will be parsed to return
             responseData = self.tavily_client.search(query)
@@ -780,35 +834,35 @@
             for d in results:
                 # assert d is dict
                 if not isinstance(d, dict):
-                    print(f'Invalid result: {d}\n')
+                    print(f"Invalid result: {d}\n")
                     continue

                 try:
                     # ensure keys are present
-                    url = d.get('url', None)
-                    title = d.get('title', None)
-                    description = d.get('content', None)
+                    url = d.get("url", None)
+                    title = d.get("title", None)
+                    description = d.get("content", None)
                     snippets = []
-                    if(d.get('raw_body_content')):
-                        snippets.append(d.get('raw_body_content'))
+                    if d.get("raw_body_content"):
+                        snippets.append(d.get("raw_body_content"))
                     else:
-                        snippets.append(d.get('content'))
+                        snippets.append(d.get("content"))

                     # raise exception of missing key(s)
                     if not all([url, title, description, snippets]):
-                        raise ValueError(f'Missing key(s) in result: {d}')
+                        raise ValueError(f"Missing key(s) in result: {d}")
                     if self.is_valid_source(url) and url not in exclude_urls:
                         result = {
-                            'url': url,
-                            'title': title,
-                            'description': description,
-                            'snippets': snippets,
+                            "url": url,
+                            "title": title,
+                            "description": description,
+                            "snippets": snippets,
                         }
                         collected_results.append(result)
                     else:
-                        print(f'invalid source {url} or url in exclude_urls')
+                        print(f"invalid source {url} or url in exclude_urls")
                 except Exception as e:
-                    print(f'Error occurs when processing {result=}: {e}\n')
-                    print(f'Error occurs when searching query {query}: {e}')
+                    print(f"Error occurs when processing {result=}: {e}\n")
+                    print(f"Error occurs when searching query {query}: {e}")

-        return collected_results
\ No newline at end of file
+        return collected_results
diff --git a/knowledge_storm/storm_wiki/engine.py b/knowledge_storm/storm_wiki/engine.py
index 77361ad..de9f5f1 100644
--- a/knowledge_storm/storm_wiki/engine.py
+++ b/knowledge_storm/storm_wiki/engine.py
@@ -28,43 +28,52 @@ class STORMWikiLMConfigs(LMConfigs):
     """

     def __init__(self):
-        self.conv_simulator_lm = None  # LLM used in conversation simulator except for question asking.
+        self.conv_simulator_lm = (
+            None  # LLM used in conversation simulator except for question asking.
+        )
         self.question_asker_lm = None  # LLM used in question asking.
         self.outline_gen_lm = None  # LLM used in outline generation.
         self.article_gen_lm = None  # LLM used in article generation.
         self.article_polish_lm = None  # LLM used in article polishing.
     def init_openai_model(
-            self,
-            openai_api_key: str,
-            openai_type: Literal["openai", "azure"],
-            api_base: Optional[str] = None,
-            api_version: Optional[str] = None,
-            temperature: Optional[float] = 1.0,
-            top_p: Optional[float] = 0.9
+        self,
+        openai_api_key: str,
+        openai_type: Literal["openai", "azure"],
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        temperature: Optional[float] = 1.0,
+        top_p: Optional[float] = 0.9,
     ):
         """Legacy: Corresponding to the original setup in the NAACL'24 paper."""
         openai_kwargs = {
-            'api_key': openai_api_key,
-            'api_provider': openai_type,
-            'temperature': temperature,
-            'top_p': top_p,
-            'api_base': None
+            "api_key": openai_api_key,
+            "api_provider": openai_type,
+            "temperature": temperature,
+            "top_p": top_p,
+            "api_base": None,
         }
-        if openai_type and openai_type == 'openai':
-            self.conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo-instruct',
-                                                 max_tokens=500, **openai_kwargs)
-            self.question_asker_lm = OpenAIModel(model='gpt-3.5-turbo',
-                                                 max_tokens=500, **openai_kwargs)
+        if openai_type and openai_type == "openai":
+            self.conv_simulator_lm = OpenAIModel(
+                model="gpt-3.5-turbo-instruct", max_tokens=500, **openai_kwargs
+            )
+            self.question_asker_lm = OpenAIModel(
+                model="gpt-3.5-turbo", max_tokens=500, **openai_kwargs
+            )
             # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.)
-            self.outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview',
-                                              max_tokens=400, **openai_kwargs)
-            self.article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13',
-                                              max_tokens=700, **openai_kwargs)
-            self.article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13',
-                                                 max_tokens=4000, **openai_kwargs)
+            self.outline_gen_lm = OpenAIModel(
+                model="gpt-4-0125-preview", max_tokens=400, **openai_kwargs
+            )
+            self.article_gen_lm = OpenAIModel(
+                model="gpt-4o-2024-05-13", max_tokens=700, **openai_kwargs
+            )
+            self.article_polish_lm = OpenAIModel(
+                model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs
+            )
         else:
-            logging.warning('No valid OpenAI API provider is provided. Cannot use default LLM configurations.')
+            logging.warning(
+                "No valid OpenAI API provider is provided. Cannot use default LLM configurations."
+            )

     def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
         self.conv_simulator_lm = model
@@ -85,16 +94,21 @@ def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
 @dataclass
 class STORMWikiRunnerArguments:
     """Arguments for controlling the STORM Wiki pipeline."""
+
     output_dir: str = field(
         metadata={"help": "Output directory for the results."},
     )
     max_conv_turn: int = field(
         default=3,
-        metadata={"help": "Maximum number of questions in conversational question asking."},
+        metadata={
+            "help": "Maximum number of questions in conversational question asking."
+        },
     )
     max_perspective: int = field(
         default=3,
-        metadata={"help": "Maximum number of perspectives to consider in perspective-guided question asking."},
+        metadata={
+            "help": "Maximum number of perspectives to consider in perspective-guided question asking."
+        },
     )
     max_search_queries_per_turn: int = field(
         default=3,
@@ -114,24 +128,27 @@ class STORMWikiRunnerArguments:
     )
     max_thread_num: int = field(
         default=10,
-        metadata={"help": "Maximum number of threads to use. "
-                          "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."},
+        metadata={
+            "help": "Maximum number of threads to use. "
+            "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."
+        },
     )


 class STORMWikiRunner(Engine):
     """STORM Wiki pipeline runner."""

-    def __init__(self,
-                 args: STORMWikiRunnerArguments,
-                 lm_configs: STORMWikiLMConfigs,
-                 rm):
+    def __init__(
+        self, args: STORMWikiRunnerArguments, lm_configs: STORMWikiLMConfigs, rm
+    ):
         super().__init__(lm_configs=lm_configs)
         self.args = args
         self.lm_configs = lm_configs

         self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k)
-        storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm)
+        storm_persona_generator = StormPersonaGenerator(
+            self.lm_configs.question_asker_lm
+        )
         self.storm_knowledge_curation_module = StormKnowledgeCurationModule(
             retriever=self.retriever,
             persona_generator=storm_persona_generator,
@@ -140,7 +157,7 @@ def __init__(self,
             max_search_queries_per_turn=self.args.max_search_queries_per_turn,
             search_top_k=self.args.search_top_k,
             max_conv_turn=self.args.max_conv_turn,
-            max_thread_num=self.args.max_thread_num
+            max_thread_num=self.args.max_thread_num,
         )
         self.storm_outline_generation_module = StormOutlineGenerationModule(
             outline_gen_lm=self.lm_configs.outline_gen_lm
@@ -148,73 +165,96 @@ def __init__(self,
         self.storm_article_generation = StormArticleGenerationModule(
             article_gen_lm=self.lm_configs.article_gen_lm,
             retrieve_top_k=self.args.retrieve_top_k,
-            max_thread_num=self.args.max_thread_num
+            max_thread_num=self.args.max_thread_num,
         )
         self.storm_article_polishing_module = StormArticlePolishingModule(
             article_gen_lm=self.lm_configs.article_gen_lm,
-            article_polish_lm=self.lm_configs.article_polish_lm
+            article_polish_lm=self.lm_configs.article_polish_lm,
         )

         self.lm_configs.init_check()
         self.apply_decorators()

-    def run_knowledge_curation_module(self,
-                                      ground_truth_url: str = "None",
-                                      callback_handler: BaseCallbackHandler = None) -> StormInformationTable:
-
-        information_table, conversation_log = self.storm_knowledge_curation_module.research(
-            topic=self.topic,
-            ground_truth_url=ground_truth_url,
-            callback_handler=callback_handler,
-            max_perspective=self.args.max_perspective,
-            disable_perspective=False,
-            return_conversation_log=True
+    def run_knowledge_curation_module(
+        self,
+        ground_truth_url: str = "None",
+        callback_handler: BaseCallbackHandler = None,
+    ) -> StormInformationTable:
+
+        information_table, conversation_log = (
+            self.storm_knowledge_curation_module.research(
+                topic=self.topic,
+                ground_truth_url=ground_truth_url,
+                callback_handler=callback_handler,
+                max_perspective=self.args.max_perspective,
+                disable_perspective=False,
+                return_conversation_log=True,
+            )
         )
-        FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json'))
-        information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json'))
+        FileIOHelper.dump_json(
+            conversation_log,
+            os.path.join(self.article_output_dir, "conversation_log.json"),
+        )
+        information_table.dump_url_to_info(
+            os.path.join(self.article_output_dir, "raw_search_results.json")
+        )
         return information_table

-    def run_outline_generation_module(self,
-                                      information_table: StormInformationTable,
-                                      callback_handler: BaseCallbackHandler = None) -> StormArticle:
+    def run_outline_generation_module(
+        self,
+        information_table: StormInformationTable,
+        callback_handler: BaseCallbackHandler = None,
+    ) -> StormArticle:

         outline, draft_outline = self.storm_outline_generation_module.generate_outline(
             topic=self.topic,
             information_table=information_table,
             return_draft_outline=True,
-            callback_handler=callback_handler
+            callback_handler=callback_handler,
+        )
+        outline.dump_outline_to_file(
+            os.path.join(self.article_output_dir, "storm_gen_outline.txt")
+        )
+        draft_outline.dump_outline_to_file(
+            os.path.join(self.article_output_dir, "direct_gen_outline.txt")
         )
-        outline.dump_outline_to_file(os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
-        draft_outline.dump_outline_to_file(os.path.join(self.article_output_dir, "direct_gen_outline.txt"))
         return outline

-    def run_article_generation_module(self,
-                                      outline: StormArticle,
-                                      information_table=StormInformationTable,
-                                      callback_handler: BaseCallbackHandler = None) -> StormArticle:
+    def run_article_generation_module(
+        self,
+        outline: StormArticle,
+        information_table=StormInformationTable,
+        callback_handler: BaseCallbackHandler = None,
+    ) -> StormArticle:

         draft_article = self.storm_article_generation.generate_article(
             topic=self.topic,
             information_table=information_table,
             article_with_outline=outline,
-            callback_handler=callback_handler
+            callback_handler=callback_handler,
+        )
+        draft_article.dump_article_as_plain_text(
+            os.path.join(self.article_output_dir, "storm_gen_article.txt")
+        )
+        draft_article.dump_reference_to_file(
+            os.path.join(self.article_output_dir, "url_to_info.json")
         )
-        draft_article.dump_article_as_plain_text(os.path.join(self.article_output_dir, 'storm_gen_article.txt'))
-        draft_article.dump_reference_to_file(os.path.join(self.article_output_dir, 'url_to_info.json'))
         return draft_article

-    def run_article_polishing_module(self,
-                                     draft_article: StormArticle,
-                                     remove_duplicate: bool = False) -> StormArticle:
+    def run_article_polishing_module(
+        self, draft_article: StormArticle, remove_duplicate: bool = False
+    ) -> StormArticle:

         polished_article = self.storm_article_polishing_module.polish_article(
             topic=self.topic,
             draft_article=draft_article,
-            remove_duplicate=remove_duplicate
+            remove_duplicate=remove_duplicate,
+        )
+        FileIOHelper.write_str(
+            polished_article.to_string(),
+            os.path.join(self.article_output_dir, "storm_gen_article_polished.txt"),
         )
-        FileIOHelper.write_str(polished_article.to_string(),
-                               os.path.join(self.article_output_dir, 'storm_gen_article_polished.txt'))
         return polished_article

     def post_run(self):
@@ -224,43 +264,61 @@ def post_run(self):
         2. Dumping the LLM call history.
         """
         config_log = self.lm_configs.log()
-        FileIOHelper.dump_json(config_log, os.path.join(self.article_output_dir, 'run_config.json'))
+        FileIOHelper.dump_json(
+            config_log, os.path.join(self.article_output_dir, "run_config.json")
+        )

         llm_call_history = self.lm_configs.collect_and_reset_lm_history()
-        with open(os.path.join(self.article_output_dir, 'llm_call_history.jsonl'), 'w') as f:
+        with open(
+            os.path.join(self.article_output_dir, "llm_call_history.jsonl"), "w"
+        ) as f:
             for call in llm_call_history:
-                if 'kwargs' in call:
-                    call.pop('kwargs')  # All kwargs are dumped together to run_config.json.
-                f.write(json.dumps(call) + '\n')
+                if "kwargs" in call:
+                    call.pop(
+                        "kwargs"
+                    )  # All kwargs are dumped together to run_config.json.
+                f.write(json.dumps(call) + "\n")

     def _load_information_table_from_local_fs(self, information_table_local_path):
         assert os.path.exists(information_table_local_path), makeStringRed(
-            f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic.")
-        return StormInformationTable.from_conversation_log_file(information_table_local_path)
+            f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic."
+        )
+        return StormInformationTable.from_conversation_log_file(
+            information_table_local_path
+        )

     def _load_outline_from_local_fs(self, topic, outline_local_path):
         assert os.path.exists(outline_local_path), makeStringRed(
-            f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.")
+            f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic."
+        )
         return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)

-    def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path):
+    def _load_draft_article_from_local_fs(
+        self, topic, draft_article_path, url_to_info_path
+    ):
         assert os.path.exists(draft_article_path), makeStringRed(
-            f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.")
+            f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic."
+        )
         assert os.path.exists(url_to_info_path), makeStringRed(
-            f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.")
+            f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic."
+        )
         article_text = FileIOHelper.load_str(draft_article_path)
         references = FileIOHelper.load_json(url_to_info_path)
-        return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references)
-
-    def run(self,
-            topic: str,
-            ground_truth_url: str = '',
-            do_research: bool = True,
-            do_generate_outline: bool = True,
-            do_generate_article: bool = True,
-            do_polish_article: bool = True,
-            remove_duplicate: bool = False,
-            callback_handler: BaseCallbackHandler = BaseCallbackHandler()):
+        return StormArticle.from_string(
+            topic_name=topic, article_text=article_text, references=references
+        )
+
+    def run(
+        self,
+        topic: str,
+        ground_truth_url: str = "",
+        do_research: bool = True,
+        do_generate_outline: bool = True,
+        do_generate_article: bool = True,
+        do_polish_article: bool = True,
+        remove_duplicate: bool = False,
+        callback_handler: BaseCallbackHandler = BaseCallbackHandler(),
+    ):
         """
         Run the STORM pipeline.

@@ -278,50 +336,76 @@ def run(self,
             remove_duplicate: If True, remove duplicated content.
             callback_handler: A callback handler to handle the intermediate results.
         """
-        assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
-            makeStringRed(
-                "No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article")
+        assert (
+            do_research
+            or do_generate_outline
+            or do_generate_article
+            or do_polish_article
+        ), makeStringRed(
+            "No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article"
Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article" + ) self.topic = topic - self.article_dir_name = truncate_filename(topic.replace(' ', '_').replace('/', '_')) - self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name) + self.article_dir_name = truncate_filename( + topic.replace(" ", "_").replace("/", "_") + ) + self.article_output_dir = os.path.join( + self.args.output_dir, self.article_dir_name + ) os.makedirs(self.article_output_dir, exist_ok=True) # research module information_table: StormInformationTable = None if do_research: - information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url, - callback_handler=callback_handler) + information_table = self.run_knowledge_curation_module( + ground_truth_url=ground_truth_url, callback_handler=callback_handler + ) # outline generation module outline: StormArticle = None if do_generate_outline: # load information table if it's not initialized if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, 'conversation_log.json')) - outline = self.run_outline_generation_module(information_table=information_table, - callback_handler=callback_handler) + os.path.join(self.article_output_dir, "conversation_log.json") + ) + outline = self.run_outline_generation_module( + information_table=information_table, callback_handler=callback_handler + ) # article generation module draft_article: StormArticle = None if do_generate_article: if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, 'conversation_log.json')) + os.path.join(self.article_output_dir, "conversation_log.json") + ) if outline is None: - outline = self._load_outline_from_local_fs(topic=topic, - outline_local_path=os.path.join(self.article_output_dir, - 'storm_gen_outline.txt')) - draft_article = self.run_article_generation_module(outline=outline, - information_table=information_table, - callback_handler=callback_handler) + outline = self._load_outline_from_local_fs( + topic=topic, + outline_local_path=os.path.join( + self.article_output_dir, "storm_gen_outline.txt" + ), + ) + draft_article = self.run_article_generation_module( + outline=outline, + information_table=information_table, + callback_handler=callback_handler, + ) # article polishing module if do_polish_article: if draft_article is None: - draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt') - url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json') - draft_article = self._load_draft_article_from_local_fs(topic=topic, - draft_article_path=draft_article_path, - url_to_info_path=url_to_info_path) - self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate) + draft_article_path = os.path.join( + self.article_output_dir, "storm_gen_article.txt" + ) + url_to_info_path = os.path.join( + self.article_output_dir, "url_to_info.json" + ) + draft_article = self._load_draft_article_from_local_fs( + topic=topic, + draft_article_path=draft_article_path, + url_to_info_path=url_to_info_path, + ) + self.run_article_polishing_module( + draft_article=draft_article, remove_duplicate=remove_duplicate + ) diff --git a/knowledge_storm/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py index a114b3e..2e71146 100644 --- 
a/knowledge_storm/storm_wiki/modules/article_generation.py +++ b/knowledge_storm/storm_wiki/modules/article_generation.py @@ -15,35 +15,48 @@ class StormArticleGenerationModule(ArticleGenerationModule): """ The interface for article generation stage. Given topic, collected information from - knowledge curation stage, generated outline from outline generation stage, + knowledge curation stage, generated outline from outline generation stage, """ - def __init__(self, - article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel], - retrieve_top_k: int = 5, - max_thread_num: int = 10): + def __init__( + self, + article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel], + retrieve_top_k: int = 5, + max_thread_num: int = 10, + ): super().__init__() self.retrieve_top_k = retrieve_top_k self.article_gen_lm = article_gen_lm self.max_thread_num = max_thread_num self.section_gen = ConvToSection(engine=self.article_gen_lm) - def generate_section(self, topic, section_name, information_table, section_outline, section_query): + def generate_section( + self, topic, section_name, information_table, section_outline, section_query + ): collected_info: List[StormInformation] = [] if information_table is not None: - collected_info = information_table.retrieve_information(queries=section_query, - search_top_k=self.retrieve_top_k) - output = self.section_gen(topic=topic, - outline=section_outline, - section=section_name, - collected_info=collected_info) - return {"section_name": section_name, "section_content": output.section, "collected_info": collected_info} - - def generate_article(self, - topic: str, - information_table: StormInformationTable, - article_with_outline: StormArticle, - callback_handler: BaseCallbackHandler = None) -> StormArticle: + collected_info = information_table.retrieve_information( + queries=section_query, search_top_k=self.retrieve_top_k + ) + output = self.section_gen( + topic=topic, + outline=section_outline, + section=section_name, + collected_info=collected_info, + ) + return { + "section_name": section_name, + "section_content": output.section, + "collected_info": collected_info, + } + + def generate_article( + self, + topic: str, + information_table: StormInformationTable, + article_with_outline: StormArticle, + callback_handler: BaseCallbackHandler = None, + ) -> StormArticle: """ Generate article for the topic based on the information table and article outline. @@ -63,35 +76,48 @@ def generate_article(self, section_output_dict_collection = [] if len(sections_to_write) == 0: - logging.error(f'No outline for {topic}. Will directly search with the topic.') + logging.error( + f"No outline for {topic}. Will directly search with the topic." + ) section_output_dict = self.generate_section( topic=topic, section_name=topic, information_table=information_table, section_outline="", - section_query=[topic] + section_query=[topic], ) section_output_dict_collection = [section_output_dict] else: - with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_thread_num + ) as executor: future_to_sec_title = {} for section_title in sections_to_write: # We don't want to write a separate introduction section. - if section_title.lower().strip() == 'introduction': + if section_title.lower().strip() == "introduction": continue # We don't want to write a separate conclusion section. 
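The filter around this point is worth restating on its own: generate_article writes body sections only, skipping the introduction and any conclusion or summary section (the polishing stage later adds a lead "# summary" section of its own). A minimal sketch of that predicate follows; the helper name should_write_section is illustrative and not part of this diff:

    def should_write_section(section_title: str) -> bool:
        # Mirrors the skip logic in StormArticleGenerationModule.generate_article.
        title = section_title.lower().strip()
        if title == "introduction":
            return False
        return not (title.startswith("conclusion") or title.startswith("summary"))

    assert should_write_section("History")
    assert not should_write_section("  Conclusion and outlook")
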
if section_title.lower().strip().startswith( - 'conclusion') or section_title.lower().strip().startswith('summary'): + "conclusion" + ) or section_title.lower().strip().startswith("summary"): continue - section_query = article_with_outline.get_outline_as_list(root_section_name=section_title, - add_hashtags=False) + section_query = article_with_outline.get_outline_as_list( + root_section_name=section_title, add_hashtags=False + ) queries_with_hashtags = article_with_outline.get_outline_as_list( - root_section_name=section_title, add_hashtags=True) + root_section_name=section_title, add_hashtags=True + ) section_outline = "\n".join(queries_with_hashtags) future_to_sec_title[ - executor.submit(self.generate_section, - topic, section_title, information_table, section_outline, section_query) + executor.submit( + self.generate_section, + topic, + section_title, + information_table, + section_outline, + section_query, + ) ] = section_title for future in as_completed(future_to_sec_title): @@ -99,9 +125,11 @@ def generate_article(self, article = copy.deepcopy(article_with_outline) for section_output_dict in section_output_dict_collection: - article.update_section(parent_section_name=topic, - current_section_content=section_output_dict["section_content"], - current_section_info_list=section_output_dict["collected_info"]) + article.update_section( + parent_section_name=topic, + current_section_content=section_output_dict["section_content"], + current_section_info_list=section_output_dict["collected_info"], + ) article.post_processing() return article @@ -114,17 +142,24 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_section = dspy.Predict(WriteSection) self.engine = engine - def forward(self, topic: str, outline: str, section: str, collected_info: List[StormInformation]): - info = '' + def forward( + self, + topic: str, + outline: str, + section: str, + collected_info: List[StormInformation], + ): + info = "" for idx, storm_info in enumerate(collected_info): - info += f'[{idx + 1}]\n' + '\n'.join(storm_info.snippets) - info += '\n\n' + info += f"[{idx + 1}]\n" + "\n".join(storm_info.snippets) + info += "\n\n" info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500) with dspy.settings.context(lm=self.engine): section = ArticleTextProcessing.clean_up_section( - self.write_section(topic=topic, info=info, section=section).output) + self.write_section(topic=topic, info=info, section=section).output + ) return dspy.Prediction(section=section) @@ -132,9 +167,9 @@ def forward(self, topic: str, outline: str, section: str, collected_info: List[S class WriteSection(dspy.Signature): """Write a Wikipedia section based on the collected information. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. 
""" info = dspy.InputField(prefix="The collected information:\n", format=str) @@ -142,5 +177,5 @@ class WriteSection(dspy.Signature): section = dspy.InputField(prefix="The section you need to write: ", format=str) output = dspy.OutputField( prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n", - format=str + format=str, ) diff --git a/knowledge_storm/storm_wiki/modules/article_polish.py b/knowledge_storm/storm_wiki/modules/article_polish.py index b70bb83..fb85b0f 100644 --- a/knowledge_storm/storm_wiki/modules/article_polish.py +++ b/knowledge_storm/storm_wiki/modules/article_polish.py @@ -14,21 +14,21 @@ class StormArticlePolishingModule(ArticlePolishingModule): knowledge curation stage, generated outline from outline generation stage. """ - def __init__(self, - article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__( + self, + article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + ): self.article_gen_lm = article_gen_lm self.article_polish_lm = article_polish_lm self.polish_page = PolishPageModule( - write_lead_engine=self.article_gen_lm, - polish_engine=self.article_polish_lm + write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm ) - def polish_article(self, - topic: str, - draft_article: StormArticle, - remove_duplicate: bool = False) -> StormArticle: + def polish_article( + self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False + ) -> StormArticle: """ Polish article. @@ -39,10 +39,14 @@ def polish_article(self, """ article_text = draft_article.to_string() - polish_result = self.polish_page(topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate) + polish_result = self.polish_page( + topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate + ) lead_section = f"# summary\n{polish_result.lead_section}" - polished_article = '\n\n'.join([lead_section, polish_result.page]) - polished_article_dict = ArticleTextProcessing.parse_article_into_dict(polished_article) + polished_article = "\n\n".join([lead_section, polish_result.page]) + polished_article_dict = ArticleTextProcessing.parse_article_into_dict( + polished_article + ) polished_article = copy.deepcopy(draft_article) polished_article.insert_or_create_section(article_dict=polished_article_dict) polished_article.post_processing() @@ -51,9 +55,10 @@ def polish_article(self, class WriteLeadSection(dspy.Signature): """Write a lead section for the given Wikipedia page with the following guidelines: - 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. - 2. The lead section should be concise and contain no more than four well-composed paragraphs. - 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary.""" + 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. + 2. 
The lead section should be concise and contain no more than four well-composed paragraphs. + 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary. + """ topic = dspy.InputField(prefix="The topic of the page: ", format=str) draft_page = dspy.InputField(prefix="The draft page:\n", format=str) @@ -68,8 +73,11 @@ class PolishPage(dspy.Signature): class PolishPageModule(dspy.Module): - def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__( + self, + write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + ): super().__init__() self.write_lead_engine = write_lead_engine self.polish_engine = polish_engine @@ -78,7 +86,9 @@ def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True): with dspy.settings.context(lm=self.write_lead_engine): - lead_section = self.write_lead(topic=topic, draft_page=draft_page).lead_section + lead_section = self.write_lead( + topic=topic, draft_page=draft_page + ).lead_section if "The lead section:" in lead_section: lead_section = lead_section.split("The lead section:")[1].strip() if polish_whole_page: diff --git a/knowledge_storm/storm_wiki/modules/knowledge_curation.py b/knowledge_storm/storm_wiki/modules/knowledge_curation.py index 8e881c6..bde2767 100644 --- a/knowledge_storm/storm_wiki/modules/knowledge_curation.py +++ b/knowledge_storm/storm_wiki/modules/knowledge_curation.py @@ -25,20 +25,32 @@ class ConvSimulator(dspy.Module): """Simulate a conversation between a Wikipedia writer with specific persona and an expert.""" - def __init__(self, topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - retriever: Retriever, max_search_queries_per_turn: int, search_top_k: int, max_turn: int): + def __init__( + self, + topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + retriever: Retriever, + max_search_queries_per_turn: int, + search_top_k: int, + max_turn: int, + ): super().__init__() self.wiki_writer = WikiWriter(engine=question_asker_engine) self.topic_expert = TopicExpert( engine=topic_expert_engine, max_search_queries=max_search_queries_per_turn, search_top_k=search_top_k, - retriever=retriever + retriever=retriever, ) self.max_turn = max_turn - def forward(self, topic: str, persona: str, ground_truth_url: str, callback_handler: BaseCallbackHandler): + def forward( + self, + topic: str, + persona: str, + ground_truth_url: str, + callback_handler: BaseCallbackHandler, + ): """ topic: The topic to research. persona: The persona of the Wikipedia writer. 
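The reformatting above leaves ConvSimulator's constructor and call signature unchanged. As a wiring sketch, under stated assumptions: search_rm is a hypothetical placeholder for a dspy.Retrieve-compatible search module (none is bundled in this file), the model choices mirror the engine.py defaults, and nothing here is introduced by the diff itself:

    from knowledge_storm.lm import OpenAIModel
    from knowledge_storm.storm_wiki.modules.knowledge_curation import ConvSimulator
    from knowledge_storm.storm_wiki.modules.retriever import StormRetriever

    # search_rm: assumed dspy.Retrieve-compatible web-search module.
    retriever = StormRetriever(rm=search_rm, k=3)

    simulator = ConvSimulator(
        topic_expert_engine=OpenAIModel(model="gpt-3.5-turbo-instruct", max_tokens=500),
        question_asker_engine=OpenAIModel(model="gpt-3.5-turbo", max_tokens=500),
        retriever=retriever,
        max_search_queries_per_turn=3,
        search_top_k=3,
        max_turn=3,
    )

Calling simulator(topic=..., persona=..., ground_truth_url=..., callback_handler=...) then runs up to max_turn writer/expert exchanges, appending one DialogueTurn per round, as the loop in the next hunk shows.
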
@@ -46,18 +58,22 @@ def forward(self, topic: str, persona: str, ground_truth_url: str, callback_hand """ dlg_history: List[DialogueTurn] = [] for _ in range(self.max_turn): - user_utterance = self.wiki_writer(topic=topic, persona=persona, dialogue_turns=dlg_history).question - if user_utterance == '': - logging.error('Simulated Wikipedia writer utterance is empty.') + user_utterance = self.wiki_writer( + topic=topic, persona=persona, dialogue_turns=dlg_history + ).question + if user_utterance == "": + logging.error("Simulated Wikipedia writer utterance is empty.") break - if user_utterance.startswith('Thank you so much for your help!'): + if user_utterance.startswith("Thank you so much for your help!"): break - expert_output = self.topic_expert(topic=topic, question=user_utterance, ground_truth_url=ground_truth_url) + expert_output = self.topic_expert( + topic=topic, question=user_utterance, ground_truth_url=ground_truth_url + ) dlg_turn = DialogueTurn( agent_utterance=expert_output.answer, user_utterance=user_utterance, search_queries=expert_output.queries, - search_results=expert_output.searched_results + search_results=expert_output.searched_results, ) dlg_history.append(dlg_turn) callback_handler.on_dialogue_turn_end(dlg_turn=dlg_turn) @@ -76,22 +92,35 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.ask_question = dspy.ChainOfThought(AskQuestion) self.engine = engine - def forward(self, topic: str, persona: str, dialogue_turns: List[DialogueTurn], draft_page=None): + def forward( + self, + topic: str, + persona: str, + dialogue_turns: List[DialogueTurn], + draft_page=None, + ): conv = [] for turn in dialogue_turns[:-4]: - conv.append(f'You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit.') + conv.append( + f"You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit." + ) for turn in dialogue_turns[-4:]: conv.append( - f'You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}') - conv = '\n'.join(conv) - conv = conv.strip() or 'N/A' + f"You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}" + ) + conv = "\n".join(conv) + conv = conv.strip() or "N/A" conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 2500) with dspy.settings.context(lm=self.engine): if persona is not None and len(persona.strip()) > 0: - question = self.ask_question_with_persona(topic=topic, persona=persona, conv=conv).question + question = self.ask_question_with_persona( + topic=topic, persona=persona, conv=conv + ).question else: - question = self.ask_question(topic=topic, persona=persona, conv=conv).question + question = self.ask_question( + topic=topic, persona=persona, conv=conv + ).question return dspy.Prediction(question=question) @@ -99,10 +128,11 @@ def forward(self, topic: str, persona: str, dialogue_turns: List[DialogueTurn], class AskQuestion(dspy.Signature): """You are an experienced Wikipedia writer. You are chatting with an expert to get information for the topic you want to contribute. Ask good questions to get more useful information relevant to the topic. When you have no more question to ask, say "Thank you so much for your help!" to end the conversation. - Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.""" + Please only ask a question at a time and don't ask what you have asked before. 
Your questions should be related to the topic you want to write. + """ - topic = dspy.InputField(prefix='Topic you want to write: ', format=str) - conv = dspy.InputField(prefix='Conversation history:\n', format=str) + topic = dspy.InputField(prefix="Topic you want to write: ", format=str) + conv = dspy.InputField(prefix="Conversation history:\n", format=str) question = dspy.OutputField(format=str) @@ -110,38 +140,41 @@ class AskQuestionWithPersona(dspy.Signature): """You are an experienced Wikipedia writer and want to edit a specific page. Besides your identity as a Wikipedia writer, you have specific focus when researching the topic. Now, you are chatting with an expert to get information. Ask good questions to get more useful information. When you have no more question to ask, say "Thank you so much for your help!" to end the conversation. - Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.""" + Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write. + """ - topic = dspy.InputField(prefix='Topic you want to write: ', format=str) - persona = dspy.InputField(prefix='Your persona besides being a Wikipedia writer: ', format=str) - conv = dspy.InputField(prefix='Conversation history:\n', format=str) + topic = dspy.InputField(prefix="Topic you want to write: ", format=str) + persona = dspy.InputField( + prefix="Your persona besides being a Wikipedia writer: ", format=str + ) + conv = dspy.InputField(prefix="Conversation history:\n", format=str) question = dspy.OutputField(format=str) class QuestionToQuery(dspy.Signature): """You want to answer the question using Google search. What do you type in the search box? - Write the queries you will use in the following format: - - query 1 - - query 2 - ... - - query n""" - - topic = dspy.InputField(prefix='Topic you are discussing about: ', format=str) - question = dspy.InputField(prefix='Question you want to answer: ', format=str) + Write the queries you will use in the following format: + - query 1 + - query 2 + ... + - query n""" + + topic = dspy.InputField(prefix="Topic you are discussing about: ", format=str) + question = dspy.InputField(prefix="Question you want to answer: ", format=str) queries = dspy.OutputField(format=str) class AnswerQuestion(dspy.Signature): """You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on topic you know. You have gathered the related information and will now use the information to form a response. - Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to he [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".""" + Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to he [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.". 
+ """ - topic = dspy.InputField(prefix='Topic you are discussing about:', format=str) - conv = dspy.InputField(prefix='Question:\n', format=str) - info = dspy.InputField( - prefix='Gathered information:\n', format=str) + topic = dspy.InputField(prefix="Topic you are discussing about:", format=str) + conv = dspy.InputField(prefix="Question:\n", format=str) + info = dspy.InputField(prefix="Gathered information:\n", format=str) answer = dspy.OutputField( - prefix='Now give your response. (Try to use as many different sources as possible and add do not hallucinate.)\n', - format=str + prefix="Now give your response. (Try to use as many different sources as possible and add do not hallucinate.)\n", + format=str, ) @@ -153,8 +186,13 @@ class TopicExpert(dspy.Module): 4. Generate an answer using the retrieved information. """ - def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - max_search_queries: int, search_top_k: int, retriever: Retriever): + def __init__( + self, + engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + max_search_queries: int, + search_top_k: int, + retriever: Retriever, + ): super().__init__() self.generate_queries = dspy.Predict(QuestionToQuery) self.retriever = retriever @@ -168,31 +206,43 @@ def forward(self, topic: str, question: str, ground_truth_url: str): with dspy.settings.context(lm=self.engine): # Identify: Break down question into queries. queries = self.generate_queries(topic=topic, question=question).queries - queries = [q.replace('-', '').strip().strip('"').strip('"').strip() for q in queries.split('\n')] - queries = queries[:self.max_search_queries] + queries = [ + q.replace("-", "").strip().strip('"').strip('"').strip() + for q in queries.split("\n") + ] + queries = queries[: self.max_search_queries] # Search - searched_results: List[StormInformation] = self.retriever.retrieve(list(set(queries)), - exclude_urls=[ground_truth_url]) + searched_results: List[StormInformation] = self.retriever.retrieve( + list(set(queries)), exclude_urls=[ground_truth_url] + ) if len(searched_results) > 0: # Evaluate: Simplify this part by directly using the top 1 snippet. - info = '' + info = "" for n, r in enumerate(searched_results): - info += '\n'.join(f'[{n + 1}]: {s}' for s in r.snippets[:1]) - info += '\n\n' + info += "\n".join(f"[{n + 1}]: {s}" for s in r.snippets[:1]) + info += "\n\n" - info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1000) + info = ArticleTextProcessing.limit_word_count_preserve_newline( + info, 1000 + ) try: - answer = self.answer_question(topic=topic, conv=question, info=info).answer - answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(answer) + answer = self.answer_question( + topic=topic, conv=question, info=info + ).answer + answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + answer + ) except Exception as e: - logging.error(f'Error occurs when generating answer: {e}') - answer = 'Sorry, I cannot answer this question. Please ask another question.' + logging.error(f"Error occurs when generating answer: {e}") + answer = "Sorry, I cannot answer this question. Please ask another question." else: # When no information is found, the expert shouldn't hallucinate. - answer = 'Sorry, I cannot find information for this question. Please ask another question.' + answer = "Sorry, I cannot find information for this question. Please ask another question." 
- return dspy.Prediction(queries=queries, searched_results=searched_results, answer=answer) + return dspy.Prediction( + queries=queries, searched_results=searched_results, answer=answer + ) class StormKnowledgeCurationModule(KnowledgeCurationModule): @@ -200,15 +250,17 @@ class StormKnowledgeCurationModule(KnowledgeCurationModule): The interface for knowledge curation stage. Given topic, return collected information. """ - def __init__(self, - retriever: Retriever, - persona_generator: Optional[StormPersonaGenerator], - conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - max_search_queries_per_turn: int, - search_top_k: int, - max_conv_turn: int, - max_thread_num: int): + def __init__( + self, + retriever: Retriever, + persona_generator: Optional[StormPersonaGenerator], + conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + max_search_queries_per_turn: int, + search_top_k: int, + max_conv_turn: int, + max_thread_num: int, + ): """ Store args and finish initialization. """ @@ -224,14 +276,22 @@ def __init__(self, retriever=retriever, max_search_queries_per_turn=max_search_queries_per_turn, search_top_k=search_top_k, - max_turn=max_conv_turn + max_turn=max_conv_turn, ) def _get_considered_personas(self, topic: str, max_num_persona) -> List[str]: - return self.persona_generator.generate_persona(topic=topic, max_num_persona=max_num_persona) + return self.persona_generator.generate_persona( + topic=topic, max_num_persona=max_num_persona + ) - def _run_conversation(self, conv_simulator, topic, ground_truth_url, considered_personas, - callback_handler: BaseCallbackHandler) -> List[Tuple[str, List[DialogueTurn]]]: + def _run_conversation( + self, + conv_simulator, + topic, + ground_truth_url, + considered_personas, + callback_handler: BaseCallbackHandler, + ) -> List[Tuple[str, List[DialogueTurn]]]: """ Executes multiple conversation simulations concurrently, each with a different persona, and collects their dialog histories. The dialog history of each conversation is cleaned @@ -260,13 +320,16 @@ def run_conv(persona): topic=topic, ground_truth_url=ground_truth_url, persona=persona, - callback_handler=callback_handler + callback_handler=callback_handler, ) max_workers = min(self.max_thread_num, len(considered_personas)) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_persona = {executor.submit(run_conv, persona): persona for persona in considered_personas} + future_to_persona = { + executor.submit(run_conv, persona): persona + for persona in considered_personas + } if streamlit_connection: # Ensure the logging context is correct when connecting with Streamlit frontend. 
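The future_to_persona mapping above is the standard fan-out-and-collect idiom from concurrent.futures. A self-contained illustration with assumed stand-in data, not part of this diff:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def run_conv(persona: str) -> list:
        # Stand-in for the real run_conv, which drives ConvSimulator and
        # returns a cleaned dialogue history for this persona.
        return [f"turn for {persona!r}"]

    personas = ["Basic fact writer", "Historian", "Economist"]
    max_thread_num = 10
    conversations = []
    with ThreadPoolExecutor(max_workers=min(max_thread_num, len(personas))) as executor:
        future_to_persona = {executor.submit(run_conv, p): p for p in personas}
        # as_completed yields futures in completion order, so one slow
        # conversation does not block collecting the finished ones.
        for future in as_completed(future_to_persona):
            conversations.append((future_to_persona[future], future.result()))
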
@@ -276,23 +339,27 @@ def run_conv(persona): for future in as_completed(future_to_persona): persona = future_to_persona[future] conv = future.result() - conversations.append((persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history)) + conversations.append( + (persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history) + ) return conversations - def research(self, - topic: str, - ground_truth_url: str, - callback_handler: BaseCallbackHandler, - max_perspective: int = 0, - disable_perspective: bool = True, - return_conversation_log=False) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: + def research( + self, + topic: str, + ground_truth_url: str, + callback_handler: BaseCallbackHandler, + max_perspective: int = 0, + disable_perspective: bool = True, + return_conversation_log=False, + ) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: """ Curate information and knowledge for the given topic Args: topic: topic of interest in natural language. - + Returns: collected_information: collected information in InformationTable type. """ @@ -303,19 +370,25 @@ def research(self, if disable_perspective: considered_personas = [""] else: - considered_personas = self._get_considered_personas(topic=topic, max_num_persona=max_perspective) + considered_personas = self._get_considered_personas( + topic=topic, max_num_persona=max_perspective + ) callback_handler.on_identify_perspective_end(perspectives=considered_personas) - # run conversation + # run conversation callback_handler.on_information_gathering_start() - conversations = self._run_conversation(conv_simulator=self.conv_simulator, - topic=topic, - ground_truth_url=ground_truth_url, - considered_personas=considered_personas, - callback_handler=callback_handler) + conversations = self._run_conversation( + conv_simulator=self.conv_simulator, + topic=topic, + ground_truth_url=ground_truth_url, + considered_personas=considered_personas, + callback_handler=callback_handler, + ) information_table = StormInformationTable(conversations) callback_handler.on_information_gathering_end() if return_conversation_log: - return information_table, StormInformationTable.construct_log_dict(conversations) + return information_table, StormInformationTable.construct_log_dict( + conversations + ) return information_table diff --git a/knowledge_storm/storm_wiki/modules/outline_generation.py b/knowledge_storm/storm_wiki/modules/outline_generation.py index 1f45b1c..a96c797 100644 --- a/knowledge_storm/storm_wiki/modules/outline_generation.py +++ b/knowledge_storm/storm_wiki/modules/outline_generation.py @@ -14,18 +14,19 @@ class StormOutlineGenerationModule(OutlineGenerationModule): curation stage, generate outline for the article. 
""" - def __init__(self, - outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): super().__init__() self.outline_gen_lm = outline_gen_lm self.write_outline = WriteOutline(engine=self.outline_gen_lm) - def generate_outline(self, - topic: str, - information_table: StormInformationTable, - old_outline: Optional[StormArticle] = None, - callback_handler: BaseCallbackHandler = None, - return_draft_outline=False) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: + def generate_outline( + self, + topic: str, + information_table: StormInformationTable, + old_outline: Optional[StormArticle] = None, + callback_handler: BaseCallbackHandler = None, + return_draft_outline=False, + ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: """ Generates an outline for an article based on the specified topic and the information gathered during the knowledge curation stage. This method can optionally return both the @@ -34,30 +35,38 @@ def generate_outline(self, Args: topic (str): The topic of the article. information_table (StormInformationTable): The information table containing the collected information. - old_outline (Optional[StormArticle]): An optional previous version of the article outline that can + old_outline (Optional[StormArticle]): An optional previous version of the article outline that can be used for reference or comparison. Defaults to None. - callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger - custom callbacks at various stages of the outline generation process, such as when the information + callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger + custom callbacks at various stages of the outline generation process, such as when the information organization starts. Defaults to None. - return_draft_outline (bool): A flag indicating whether the method should return both the final article - outline and a draft version of the outline. If False, only the final article outline is returned. + return_draft_outline (bool): A flag indicating whether the method should return both the final article + outline and a draft version of the outline. If False, only the final article outline is returned. Defaults to False. Returns: - Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, - this method returns either a single `StormArticle` object containing the final outline or a tuple of - two `StormArticle` objects, the first containing the final outline and the second containing the + Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, + this method returns either a single `StormArticle` object containing the final outline or a tuple of + two `StormArticle` objects, the first containing the final outline and the second containing the draft outline. 
""" if callback_handler is not None: callback_handler.on_information_organization_start() - concatenated_dialogue_turns = sum([conv for (_, conv) in information_table.conversations], []) - result = self.write_outline(topic=topic, dlg_history=concatenated_dialogue_turns, - callback_handler=callback_handler) - article_with_outline_only = StormArticle.from_outline_str(topic=topic, outline_str=result.outline) - article_with_draft_outline_only = StormArticle.from_outline_str(topic=topic, - outline_str=result.old_outline) + concatenated_dialogue_turns = sum( + [conv for (_, conv) in information_table.conversations], [] + ) + result = self.write_outline( + topic=topic, + dlg_history=concatenated_dialogue_turns, + callback_handler=callback_handler, + ) + article_with_outline_only = StormArticle.from_outline_str( + topic=topic, outline_str=result.outline + ) + article_with_draft_outline_only = StormArticle.from_outline_str( + topic=topic, outline_str=result.old_outline + ) if not return_draft_outline: return article_with_outline_only return article_with_outline_only, article_with_draft_outline_only @@ -72,25 +81,44 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_page_outline = dspy.Predict(WritePageOutlineFromConv) self.engine = engine - def forward(self, topic: str, dlg_history, old_outline: Optional[str] = None, - callback_handler: BaseCallbackHandler = None): + def forward( + self, + topic: str, + dlg_history, + old_outline: Optional[str] = None, + callback_handler: BaseCallbackHandler = None, + ): trimmed_dlg_history = [] for turn in dlg_history: - if 'topic you' in turn.agent_utterance.lower() or 'topic you' in turn.user_utterance.lower(): + if ( + "topic you" in turn.agent_utterance.lower() + or "topic you" in turn.user_utterance.lower() + ): continue trimmed_dlg_history.append(turn) - conv = '\n'.join([f'Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}' for turn in - trimmed_dlg_history]) + conv = "\n".join( + [ + f"Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}" + for turn in trimmed_dlg_history + ] + ) conv = ArticleTextProcessing.remove_citations(conv) conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000) with dspy.settings.context(lm=self.engine): if old_outline is None: - old_outline = ArticleTextProcessing.clean_up_outline(self.draft_page_outline(topic=topic).outline) + old_outline = ArticleTextProcessing.clean_up_outline( + self.draft_page_outline(topic=topic).outline + ) if callback_handler: - callback_handler.on_direct_outline_generation_end(outline=old_outline) + callback_handler.on_direct_outline_generation_end( + outline=old_outline + ) outline = ArticleTextProcessing.clean_up_outline( - self.write_page_outline(topic=topic, old_outline=old_outline, conv=conv).outline) + self.write_page_outline( + topic=topic, old_outline=old_outline, conv=conv + ).outline + ) if callback_handler: callback_handler.on_outline_refinement_end(outline=outline) @@ -99,10 +127,10 @@ def forward(self, topic: str, dlg_history, old_outline: Optional[str] = None, class WritePageOutline(dspy.Signature): """Write an outline for a Wikipedia page. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. 
Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -124,10 +152,10 @@ def forward(self, topic: str): class WritePageOutlineFromConv(dspy.Signature): """Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -135,5 +163,5 @@ class WritePageOutlineFromConv(dspy.Signature): old_outline = dspy.OutputField(prefix="Current outline:\n", format=str) outline = dspy.OutputField( prefix='Write the Wikipedia page outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n', - format=str + format=str, ) diff --git a/knowledge_storm/storm_wiki/modules/persona_generator.py b/knowledge_storm/storm_wiki/modules/persona_generator.py index 5150e31..c51dc0c 100644 --- a/knowledge_storm/storm_wiki/modules/persona_generator.py +++ b/knowledge_storm/storm_wiki/modules/persona_generator.py @@ -11,19 +11,27 @@ def get_wiki_page_title_and_toc(url): """Get the main title and table of contents from an url of a Wikipedia page.""" response = requests.get(url) - soup = BeautifulSoup(response.content, 'html.parser') + soup = BeautifulSoup(response.content, "html.parser") # Get the main title from the first h1 tag - main_title = soup.find('h1').text.replace('[edit]', '').strip().replace('\xa0', ' ') + main_title = soup.find("h1").text.replace("[edit]", "").strip().replace("\xa0", " ") toc = "" levels = [] - excluded_sections = {'Contents', 'See also', 'Notes', 'References', 'External links'} + excluded_sections = { + "Contents", + "See also", + "Notes", + "References", + "External links", + } # Start processing from h2 to exclude the main title from TOC - for header in soup.find_all(['h2', 'h3', "h4", "h5", "h6"]): - level = int(header.name[1]) # Extract the numeric part of the header tag (e.g., '2' from 'h2') - section_title = header.text.replace('[edit]', '').strip().replace('\xa0', ' ') + for header in soup.find_all(["h2", "h3", "h4", "h5", "h6"]): + level = int( + header.name[1] + ) # Extract the numeric part of the header tag (e.g., '2' from 'h2') + section_title = header.text.replace("[edit]", "").strip().replace("\xa0", " ") if section_title in excluded_sections: continue @@ -39,9 +47,9 @@ def get_wiki_page_title_and_toc(url): class FindRelatedTopic(dspy.Signature): """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. 
I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics. - Please list the urls in separate lines.""" + Please list the urls in separate lines.""" - topic = dspy.InputField(prefix='Topic of interest:', format=str) + topic = dspy.InputField(prefix="Topic of interest:", format=str) related_topics = dspy.OutputField(format=str) @@ -50,8 +58,10 @@ class GenPersona(dspy.Signature): Give your answer in the following format: 1. short summary of editor 1: description\n2. short summary of editor 2: description\n... """ - topic = dspy.InputField(prefix='Topic of interest:', format=str) - examples = dspy.InputField(prefix='Wiki page outlines of related topics for inspiration:\n', format=str) + topic = dspy.InputField(prefix="Topic of interest:", format=str) + examples = dspy.InputField( + prefix="Wiki page outlines of related topics for inspiration:\n", format=str + ) personas = dspy.OutputField(format=str) @@ -69,38 +79,44 @@ def forward(self, topic: str, draft=None): # Get section names from wiki pages of relevant topics for inspiration. related_topics = self.find_related_topic(topic=topic).related_topics urls = [] - for s in related_topics.split('\n'): - if 'http' in s: - urls.append(s[s.find('http'):]) + for s in related_topics.split("\n"): + if "http" in s: + urls.append(s[s.find("http") :]) examples = [] for url in urls: try: title, toc = get_wiki_page_title_and_toc(url) - examples.append(f'Title: {title}\nTable of Contents: {toc}') + examples.append(f"Title: {title}\nTable of Contents: {toc}") except Exception as e: - logging.error(f'Error occurs when processing {url}: {e}') + logging.error(f"Error occurs when processing {url}: {e}") continue if len(examples) == 0: - examples.append('N/A') - gen_persona_output = self.gen_persona(topic=topic, examples='\n----------\n'.join(examples)).personas + examples.append("N/A") + gen_persona_output = self.gen_persona( + topic=topic, examples="\n----------\n".join(examples) + ).personas personas = [] - for s in gen_persona_output.split('\n'): - match = re.search(r'\d+\.\s*(.*)', s) + for s in gen_persona_output.split("\n"): + match = re.search(r"\d+\.\s*(.*)", s) if match: personas.append(match.group(1)) sorted_personas = personas - return dspy.Prediction(personas=personas, raw_personas_output=sorted_personas, related_topics=related_topics) + return dspy.Prediction( + personas=personas, + raw_personas_output=sorted_personas, + related_topics=related_topics, + ) -class StormPersonaGenerator(): +class StormPersonaGenerator: """ A generator class for creating personas based on a given topic. - This class uses an underlying engine to generate personas tailored to the specified topic. - The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, + This class uses an underlying engine to generate personas tailored to the specified topic. + The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, including a default 'Basic fact writer' persona. Attributes: @@ -133,6 +149,6 @@ def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]: and up to `max_num_persona` additional personas generated based on the topic. """ personas = self.create_writer_with_persona(topic=topic) - default_persona = 'Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic.' 
+ default_persona = "Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic." considered_personas = [default_persona] + personas.personas[:max_num_persona] return considered_personas diff --git a/knowledge_storm/storm_wiki/modules/retriever.py b/knowledge_storm/storm_wiki/modules/retriever.py index 179ae99..85df63e 100644 --- a/knowledge_storm/storm_wiki/modules/retriever.py +++ b/knowledge_storm/storm_wiki/modules/retriever.py @@ -149,7 +149,8 @@ "WordPress.com", "Worldometer", "YouTube", - "ZDNet"} + "ZDNet", +} DEPRECATED = { "Al_Mayadeen", "ANNA_News", @@ -197,7 +198,7 @@ "VDARE", "Voltaire_Network", "WorldNetDaily", - "Zero_Hedge" + "Zero_Hedge", } BLACKLISTED = { "Advameg", @@ -218,7 +219,7 @@ "The_Points_Guy_(sponsored_content)", "Swarajya", "Veterans_Today", - "ZoomInfo" + "ZoomInfo", } @@ -237,14 +238,20 @@ class StormRetriever(Retriever): def __init__(self, rm: dspy.Retrieve, k=3): super().__init__(search_top_k=k) self._rm = rm - if hasattr(rm, 'is_valid_source'): + if hasattr(rm, "is_valid_source"): rm.is_valid_source = is_valid_wikipedia_source - def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]: - retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls) + def retrieve( + self, query: Union[str, List[str]], exclude_urls: List[str] = [] + ) -> List[Information]: + retrieved_data_list = self._rm( + query_or_queries=query, exclude_urls=exclude_urls + ) for data in retrieved_data_list: - for i in range(len(data['snippets'])): + for i in range(len(data["snippets"])): # STORM generate the article with citations. We do not consider multi-hop citations. # Remove citations in the source to avoid confusion. - data['snippets'][i] = ArticleTextProcessing.remove_citations(data['snippets'][i]) + data["snippets"][i] = ArticleTextProcessing.remove_citations( + data["snippets"][i] + ) return [StormInformation.from_dict(data) for data in retrieved_data_list] diff --git a/knowledge_storm/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py index 4f54ec4..43826ec 100644 --- a/knowledge_storm/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -51,22 +51,29 @@ def from_dict(cls, info_dict): Returns: StormInformation: An instance of StormInformation. 
""" - return cls(info_dict['url'], info_dict['description'], info_dict['snippets'], info_dict['title']) + return cls( + info_dict["url"], + info_dict["description"], + info_dict["snippets"], + info_dict["title"], + ) def to_dict(self): - return {"url": self.uuid, - "description": self.description, - "snippets": self.snippets, - "title": self.title} + return { + "url": self.uuid, + "description": self.description, + "snippets": self.snippets, + "title": self.title, + } class DialogueTurn: def __init__( - self, - agent_utterance: str = None, - user_utterance: str = None, - search_queries: Optional[List[str]] = None, - search_results: Optional[List[Union[StormInformation, Dict]]] = None + self, + agent_utterance: str = None, + user_utterance: str = None, + search_queries: Optional[List[str]] = None, + search_results: Optional[List[Union[StormInformation, Dict]]] = None, ): self.agent_utterance = agent_utterance self.user_utterance = user_utterance @@ -76,7 +83,9 @@ def __init__( if self.search_results: for idx in range(len(self.search_results)): if type(self.search_results[idx]) == dict: - self.search_results[idx] = StormInformation.from_dict(self.search_results[idx]) + self.search_results[idx] = StormInformation.from_dict( + self.search_results[idx] + ) def log(self): """ @@ -85,10 +94,10 @@ def log(self): return OrderedDict( { - 'agent_utterance': self.agent_utterance, - 'user_utterance': self.user_utterance, - 'search_queries': self.search_queries, - 'search_results': [data.to_dict() for data in self.search_results], + "agent_utterance": self.agent_utterance, + "user_utterance": self.user_utterance, + "search_queries": self.search_queries, + "search_results": [data.to_dict() for data in self.search_results], } ) @@ -98,7 +107,7 @@ class StormInformationTable(InformationTable): The InformationTable class serves as data class to store the information collected during KnowledgeCuration stage. - Create subclass to incorporate more information as needed. For example, + Create subclass to incorporate more information as needed. For example, in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information would be perspective guided dialogue history. 
""" @@ -106,13 +115,17 @@ class StormInformationTable(InformationTable): def __init__(self, conversations=List[Tuple[str, List[DialogueTurn]]]): super().__init__() self.conversations = conversations - self.url_to_info: Dict[str, StormInformation] = StormInformationTable.construct_url_to_info(self.conversations) + self.url_to_info: Dict[str, StormInformation] = ( + StormInformationTable.construct_url_to_info(self.conversations) + ) @staticmethod - def construct_url_to_info(conversations: List[Tuple[str, List[DialogueTurn]]]) -> Dict[str, StormInformation]: + def construct_url_to_info( + conversations: List[Tuple[str, List[DialogueTurn]]] + ) -> Dict[str, StormInformation]: url_to_info = {} - for (persona, conv) in conversations: + for persona, conv in conversations: for turn in conv: for storm_info in turn.search_results: if storm_info.url in url_to_info: @@ -124,14 +137,13 @@ def construct_url_to_info(conversations: List[Tuple[str, List[DialogueTurn]]]) - return url_to_info @staticmethod - def construct_log_dict(conversations: List[Tuple[str, List[DialogueTurn]]]) -> List[Dict[str, Union[str, Any]]]: + def construct_log_dict( + conversations: List[Tuple[str, List[DialogueTurn]]] + ) -> List[Dict[str, Union[str, Any]]]: conversation_log = [] - for (persona, conv) in conversations: + for persona, conv in conversations: conversation_log.append( - { - 'perspective': persona, - 'dlg_turns': [turn.log() for turn in conv] - } + {"perspective": persona, "dlg_turns": [turn.log() for turn in conv]} ) return conversation_log @@ -146,22 +158,26 @@ def from_conversation_log_file(cls, path): conversation_log_data = FileIOHelper.load_json(path) conversations = [] for item in conversation_log_data: - dialogue_turns = [DialogueTurn(**turn) for turn in item['dlg_turns']] - persona = item['perspective'] + dialogue_turns = [DialogueTurn(**turn) for turn in item["dlg_turns"]] + persona = item["perspective"] conversations.append((persona, dialogue_turns)) return cls(conversations) def prepare_table_for_retrieval(self): - self.encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2') + self.encoder = SentenceTransformer("paraphrase-MiniLM-L6-v2") self.collected_urls = [] self.collected_snippets = [] for url, information in self.url_to_info.items(): for snippet in information.snippets: self.collected_urls.append(url) self.collected_snippets.append(snippet) - self.encoded_snippets = self.encoder.encode(self.collected_snippets, show_progress_bar=False) + self.encoded_snippets = self.encoder.encode( + self.collected_snippets, show_progress_bar=False + ) - def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> List[StormInformation]: + def retrieve_information( + self, queries: Union[List[str], str], search_top_k + ) -> List[StormInformation]: selected_urls = [] selected_snippets = [] if type(queries) is str: @@ -191,14 +207,13 @@ def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> class StormArticle(Article): def __init__(self, topic_name): super().__init__(topic_name=topic_name) - self.reference = { - "url_to_unified_index": {}, - "url_to_info": {} - } + self.reference = {"url_to_unified_index": {}, "url_to_info": {}} - def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: + def find_section( + self, node: ArticleSectionNode, name: str + ) -> Optional[ArticleSectionNode]: """ - Return the node of the section given the section name. + Return the node of the section given the section name. 
Args:
 node: the node as the root to find.
@@ -215,17 +230,18 @@ def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleS
 return result
 return None
- def _merge_new_info_to_references(self, new_info_list: List[StormInformation], index_to_keep=None) -> Dict[
- int, int]:
+ def _merge_new_info_to_references(
+ self, new_info_list: List[StormInformation], index_to_keep=None
+ ) -> Dict[int, int]:
 """
 Merges new storm information into existing references and updates the citation index mapping.
 Args:
- new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. 
+ new_info_list (List[StormInformation]): A list of StormInformation objects representing new storm information.
 index_to_keep (List[int]): A list of indices of the new_info_list to keep. If None, keep all.
 Returns:
- Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list 
+ Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list
 to its unified citation index in the references.
 """
 citation_idx_mapping = {}
@@ -234,20 +250,32 @@ def _merge_new_info_to_references(self, new_info_list: List[StormInformation], i
 continue
 url = storm_info.url
 if url not in self.reference["url_to_unified_index"]:
- self.reference["url_to_unified_index"][url] = len(
- self.reference["url_to_unified_index"]) + 1 # The citation index starts from 1.
+ self.reference["url_to_unified_index"][url] = (
+ len(self.reference["url_to_unified_index"]) + 1
+ ) # The citation index starts from 1.
 self.reference["url_to_info"][url] = storm_info
 else:
 existing_snippets = self.reference["url_to_info"][url].snippets
 existing_snippets.extend(storm_info.snippets)
- self.reference["url_to_info"][url].snippets = list(set(existing_snippets))
+ self.reference["url_to_info"][url].snippets = list(
+ set(existing_snippets)
+ )
 citation_idx_mapping[idx + 1] = self.reference["url_to_unified_index"][
- url] # The citation index starts from 1.
+ url
+ ] # The citation index starts from 1.
return citation_idx_mapping
- def insert_or_create_section(self, article_dict: Dict[str, Dict], parent_section_name: str = None,
- trim_children=False):
- parent_node = self.root if parent_section_name is None else self.find_section(self.root, parent_section_name)
+ def insert_or_create_section(
+ self,
+ article_dict: Dict[str, Dict],
+ parent_section_name: str = None,
+ trim_children=False,
+ ):
+ parent_node = (
+ self.root
+ if parent_section_name is None
+ else self.find_section(self.root, parent_section_name)
+ )
 if trim_children:
 section_names = set(article_dict.keys())
@@ -258,56 +286,83 @@ def insert_or_create_section(self, article_dict: Dict[str, Dict], parent_section
 for section_name, content_dict in article_dict.items():
 current_section_node = self.find_section(parent_node, section_name)
 if current_section_node is None:
- current_section_node = ArticleSectionNode(section_name=section_name,
- content=content_dict["content"].strip())
- insert_to_front = parent_node.section_name == self.root.section_name and current_section_node.section_name == "summary"
- parent_node.add_child(current_section_node, insert_to_front=insert_to_front)
+ current_section_node = ArticleSectionNode(
+ section_name=section_name, content=content_dict["content"].strip()
+ )
+ insert_to_front = (
+ parent_node.section_name == self.root.section_name
+ and current_section_node.section_name == "summary"
+ )
+ parent_node.add_child(
+ current_section_node, insert_to_front=insert_to_front
+ )
 else:
 current_section_node.content = content_dict["content"].strip()
- self.insert_or_create_section(article_dict=content_dict["subsections"], parent_section_name=section_name,
- trim_children=True)
+ self.insert_or_create_section(
+ article_dict=content_dict["subsections"],
+ parent_section_name=section_name,
+ trim_children=True,
+ )
- def update_section(self,
- current_section_content: str,
- current_section_info_list: List[StormInformation],
- parent_section_name: Optional[str] = None) -> Optional[ArticleSectionNode]:
+ def update_section(
+ self,
+ current_section_content: str,
+ current_section_info_list: List[StormInformation],
+ parent_section_name: Optional[str] = None,
+ ) -> Optional[ArticleSectionNode]:
 """
- Add new section to the article. 
+ Add new section to the article.
 Args:
 current_section_content: the new section content in string format, beginning with the section heading.
 parent_section_name: under which parent section to add the new one. Defaults to root.
- current_section_content: optional section content. 
- 
+ current_section_info_list: a list of StormInformation objects supporting the section content.
+
 Returns:
 the ArticleSectionNode for the current section if successfully created / updated. Otherwise None.
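A sketch of the intended call (topic, content, and source are illustrative): `update_section` merges the supporting source into the article's references and re-maps the in-text `[1]` to its unified citation index:

    from knowledge_storm.storm_wiki.modules.storm_dataclass import (
        StormArticle,
        StormInformation,
    )

    article = StormArticle(topic_name="Example topic")
    source = StormInformation.from_dict(
        {
            "url": "https://example.org/source",
            "description": "An example source.",
            "snippets": ["STORM curates information before it writes."],
            "title": "Example source",
        }
    )
    article.update_section(
        current_section_content="# Background\nSTORM drafts from curated sources [1].",
        current_section_info_list=[source],
    )
    # {'https://example.org/source': 1}
    print(article.reference["url_to_unified_index"])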
""" if current_section_info_list is not None: - references = set([int(x) for x in re.findall(r'\[(\d+)\]', current_section_content)]) + references = set( + [int(x) for x in re.findall(r"\[(\d+)\]", current_section_content)] + ) # for any reference number greater than max number of references, delete the reference if len(references) > 0: max_ref_num = max(references) if max_ref_num > len(current_section_info_list): for i in range(len(current_section_info_list), max_ref_num + 1): - current_section_content = current_section_content.replace(f'[{i}]', '') + current_section_content = current_section_content.replace( + f"[{i}]", "" + ) if i in references: references.remove(i) # for any reference that is not used, trim it from current_section_info_list index_to_keep = [i - 1 for i in references] - citation_mapping = self._merge_new_info_to_references(current_section_info_list, index_to_keep) - current_section_content = ArticleTextProcessing.update_citation_index(current_section_content, - citation_mapping) + citation_mapping = self._merge_new_info_to_references( + current_section_info_list, index_to_keep + ) + current_section_content = ArticleTextProcessing.update_citation_index( + current_section_content, citation_mapping + ) if parent_section_name is None: parent_section_name = self.root.section_name - article_dict = ArticleTextProcessing.parse_article_into_dict(current_section_content) - self.insert_or_create_section(article_dict=article_dict, parent_section_name=parent_section_name, - trim_children=False) + article_dict = ArticleTextProcessing.parse_article_into_dict( + current_section_content + ) + self.insert_or_create_section( + article_dict=article_dict, + parent_section_name=parent_section_name, + trim_children=False, + ) - def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hashtags: bool = False, - include_root: bool = True) -> List[str]: + def get_outline_as_list( + self, + root_section_name: Optional[str] = None, + add_hashtags: bool = False, + include_root: bool = True, + ) -> List[str]: """ Get outline of the article as a list. @@ -320,7 +375,7 @@ def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hasht ###section1.2 ##section2 article.get_outline_as_list("section1") returns [section1, section1.1, section1.2, section2] - + Returns: list of section and subsection names. """ @@ -334,8 +389,14 @@ def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hasht result = [] def preorder_traverse(node, level): - prefix = "#" * level if add_hashtags else "" # Adjust level if excluding root - result.append(f"{prefix} {node.section_name}".strip() if add_hashtags else node.section_name) + prefix = ( + "#" * level if add_hashtags else "" + ) # Adjust level if excluding root + result.append( + f"{prefix} {node.section_name}".strip() + if add_hashtags + else node.section_name + ) for child in node.children: preorder_traverse(child, level + 1) @@ -350,7 +411,7 @@ def preorder_traverse(node, level): def to_string(self) -> str: """ Get outline of the article as a list. - + Returns: list of section and subsection names. 
""" @@ -376,7 +437,9 @@ def reorder_reference_index(self): def pre_order_find_index(node): if node is not None: if node.content is not None and node.content: - ref_indices.extend(ArticleTextProcessing.parse_citation_indices(node.content)) + ref_indices.extend( + ArticleTextProcessing.parse_citation_indices(node.content) + ) for child in node.children: pre_order_find_index(child) @@ -391,7 +454,9 @@ def pre_order_find_index(node): def pre_order_update_index(node): if node is not None: if node.content is not None and node.content: - node.content = ArticleTextProcessing.update_citation_index(node.content, ref_index_mapping) + node.content = ArticleTextProcessing.update_citation_index( + node.content, ref_index_mapping + ) for child in node.children: pre_order_update_index(child) @@ -442,18 +507,18 @@ def from_outline_str(cls, topic: str, outline_str: str): instance = cls(topic) if lines: - a = lines[0].startswith('#') and lines[0].replace('#', '').strip().lower() + a = lines[0].startswith("#") and lines[0].replace("#", "").strip().lower() b = topic.lower().replace("_", " ") - adjust_level = lines[0].startswith('#') and lines[0].replace('#', - '').strip().lower() == topic.lower().replace( - "_", " ") + adjust_level = lines[0].startswith("#") and lines[0].replace( + "#", "" + ).strip().lower() == topic.lower().replace("_", " ") if adjust_level: lines = lines[1:] node_stack = [(0, instance.root)] # Stack to keep track of (level, node) for line in lines: - level = line.count('#') - adjust_level - section_name = line.replace('#', '').strip() + level = line.count("#") - adjust_level + section_name = line.replace("#", "").strip() if section_name == topic: continue @@ -487,7 +552,9 @@ def from_string(cls, topic_name: str, article_text: str, references: dict): article = cls(topic_name=topic_name) article.insert_or_create_section(article_dict=article_dict) for url in list(references["url_to_info"]): - references["url_to_info"][url] = StormInformation.from_dict(references["url_to_info"][url]) + references["url_to_info"][url] = StormInformation.from_dict( + references["url_to_info"][url] + ) article.reference = references return article diff --git a/knowledge_storm/utils.py b/knowledge_storm/utils.py index 36f68fd..1749609 100644 --- a/knowledge_storm/utils.py +++ b/knowledge_storm/utils.py @@ -20,9 +20,10 @@ logging.getLogger("httpx").setLevel(logging.WARNING) # Disable INFO logging for httpx. + def truncate_filename(filename, max_length=125): """Truncate filename to max_length to ensure the filename won't exceed the file system limit. - + Args: filename: str max_length: int, default to 125 (usual path length limit is 255 chars) @@ -30,14 +31,17 @@ def truncate_filename(filename, max_length=125): if len(filename) > max_length: truncated_filename = filename[:max_length] - logging.warning(f"Filename is too long. Filename is truncated to {truncated_filename}.") + logging.warning( + f"Filename is too long. Filename is truncated to {truncated_filename}." + ) return truncated_filename return filename + def load_api_key(toml_file_path): try: - with open(toml_file_path, 'r') as file: + with open(toml_file_path, "r") as file: data = toml.load(file) except FileNotFoundError: print(f"File not found: {toml_file_path}", file=sys.stderr) @@ -57,12 +61,15 @@ def makeStringRed(message): class QdrantVectorStoreManager: """ Helper class for managing the Qdrant vector store, can be used with `VectorRM` in rm.py. 
- 
+
 Before you initialize `VectorRM`, call `create_or_update_vector_store` to create or update the vector store.
 Once you have the vector store, you can initialize `VectorRM` with the vector store path or the Qdrant server URL.
 """
+
 @staticmethod
- def _check_create_collection(client: QdrantClient, collection_name: str, model: HuggingFaceEmbeddings):
+ def _check_create_collection(
+ client: QdrantClient, collection_name: str, model: HuggingFaceEmbeddings
+ ):
 """Check if the Qdrant collection exists and create it if it does not."""
 if client is None:
 raise ValueError("Qdrant client is not initialized.")
@@ -74,20 +81,26 @@ def _check_create_collection(client: QdrantClient, collection_name: str, model:
 embeddings=model,
 )
 else:
- print(f"Collection {collection_name} does not exist. Creating the collection...")
+ print(
+ f"Collection {collection_name} does not exist. Creating the collection..."
+ )
 # create the collection
 client.create_collection(
 collection_name=f"{collection_name}",
- vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE),
+ vectors_config=models.VectorParams(
+ size=1024, distance=models.Distance.COSINE
+ ),
 )
 return Qdrant(
 client=client,
 collection_name=collection_name,
 embeddings=model,
 )
- 
+
 @staticmethod
- def _init_online_vector_db(url: str, api_key: str, collection_name: str, model: HuggingFaceEmbeddings):
+ def _init_online_vector_db(
+ url: str, api_key: str, collection_name: str, model: HuggingFaceEmbeddings
+ ):
 """Initialize the Qdrant client that is connected to an online vector store with the given URL and API key.
 Args:
@@ -103,12 +116,16 @@ def _init_online_vector_db(url: str, api_key: str, collection_name: str, model:
 try:
 client = QdrantClient(url=url, api_key=api_key)
- return QdrantVectorStoreManager._check_create_collection(client=client, collection_name=collection_name, model=model)
+ return QdrantVectorStoreManager._check_create_collection(
+ client=client, collection_name=collection_name, model=model
+ )
 except Exception as e:
 raise ValueError(f"Error occurred while connecting to the server: {e}")
 @staticmethod
- def _init_offline_vector_db(vector_store_path: str, collection_name: str, model: HuggingFaceEmbeddings):
+ def _init_offline_vector_db(
+ vector_store_path: str, collection_name: str, model: HuggingFaceEmbeddings
+ ):
 """Initialize the Qdrant client that is connected to an offline vector store with the given vector store folder path.
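Both initializers funnel into `create_or_update_vector_store`, whose signature follows below. A minimal offline-mode sketch, with the collection name, CSV path, and store directory as placeholders:

    from knowledge_storm.utils import QdrantVectorStoreManager

    QdrantVectorStoreManager.create_or_update_vector_store(
        collection_name="my_corpus",         # placeholder collection name
        vector_db_mode="offline",            # or "online" with url / qdrant_api_key
        file_path="documents.csv",           # placeholder CSV with content/title/url/description columns
        content_column="content",
        vector_store_path="./vector_store",  # placeholder local Qdrant directory
        device="cpu",                        # the default "mps" assumes Apple silicon
    )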
Args:
@@ -119,37 +136,39 @@ def _init_offline_vector_db(vector_store_path: str, collection_name: str, model:
 try:
 client = QdrantClient(path=vector_store_path)
- return QdrantVectorStoreManager._check_create_collection(client=client, collection_name=collection_name, model=model)
+ return QdrantVectorStoreManager._check_create_collection(
+ client=client, collection_name=collection_name, model=model
+ )
 except Exception as e:
 raise ValueError(f"Error occurred while loading the vector store: {e}")
- 
+
 @staticmethod
 def create_or_update_vector_store(
- collection_name: str,
- vector_db_mode: str,
- file_path: str,
- content_column: str,
- title_column: str = "title",
- url_column: str = "url",
- desc_column: str = "description",
- batch_size: int = 64,
- chunk_size: int = 500,
- chunk_overlap: int = 100,
- vector_store_path: str = None,
- url: str = None,
- qdrant_api_key: str = None,
- embedding_model: str = 'BAAI/bge-m3',
- device: str = "mps",
+ collection_name: str,
+ vector_db_mode: str,
+ file_path: str,
+ content_column: str,
+ title_column: str = "title",
+ url_column: str = "url",
+ desc_column: str = "description",
+ batch_size: int = 64,
+ chunk_size: int = 500,
+ chunk_overlap: int = 100,
+ vector_store_path: str = None,
+ url: str = None,
+ qdrant_api_key: str = None,
+ embedding_model: str = "BAAI/bge-m3",
+ device: str = "mps",
 ):
 """
 Takes a CSV file and adds each row in the CSV file to the Qdrant collection.
- 
+
 This function expects each row of the CSV file as a document.
 The CSV file should have columns for "content", "title", "URL", and "description".
 Args:
 collection_name: Name of the Qdrant collection.
- vector_store_path (str): Path to the directory where the vector store is stored or will be stored. 
+ vector_store_path (str): Path to the directory where the vector store is stored or will be stored.
 vector_db_mode (str): Mode of the Qdrant vector store (offline or online).
 file_path (str): Path to the CSV file.
 content_column (str): Name of the column containing the content.
@@ -166,17 +185,19 @@ def create_or_update_vector_store(
 # check if the collection name is provided
 if collection_name is None:
 raise ValueError("Please provide a collection name.")
- 
+
 model_kwargs = {"device": device}
 encode_kwargs = {"normalize_embeddings": True}
 model = HuggingFaceEmbeddings(
- model_name=embedding_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
+ model_name=embedding_model,
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs,
 )
 if file_path is None:
 raise ValueError("Please provide a file path.")
 # check if the file is a csv file
- if not file_path.endswith('.csv'):
+ if not file_path.endswith(".csv"):
 raise ValueError(f"Invalid file format. Please provide a csv file.")
 if content_column is None:
 raise ValueError("Please provide the name of the content column.")
@@ -185,17 +206,23 @@
 # try to initialize the Qdrant client
 qdrant = None
- if vector_db_mode == 'online':
+ if vector_db_mode == "online":
 qdrant = QdrantVectorStoreManager._init_online_vector_db(
 url=url,
 api_key=qdrant_api_key,
 collection_name=collection_name,
 model=model,
 )
- elif vector_db_mode == 'offline':
- qdrant = QdrantVectorStoreManager._init_offline_vector_db(vector_store_path=vector_store_path, collection_name=collection_name, model=model)
+ elif vector_db_mode == "offline":
+ qdrant = QdrantVectorStoreManager._init_offline_vector_db(
+ vector_store_path=vector_store_path,
+ collection_name=collection_name,
+ model=model,
+ )
 else:
- raise ValueError("Invalid vector_db_mode. Please provide either 'online' or 'offline'.")
+ raise ValueError(
+ "Invalid vector_db_mode. Please provide either 'online' or 'offline'."
+ )
 if qdrant is None:
 raise ValueError("Qdrant client is not initialized.")
@@ -203,7 +230,9 @@
 df = pd.read_csv(file_path)
 # check that content column exists and url column exists
 if content_column not in df.columns:
- raise ValueError(f"Content column {content_column} not found in the csv file.")
+ raise ValueError(
+ f"Content column {content_column} not found in the csv file."
+ )
 if url_column not in df.columns:
 raise ValueError(f"URL column {url_column} not found in the csv file.")
@@ -211,16 +240,17 @@
 Document(
 page_content=row[content_column],
 metadata={
- "title": row.get(title_column, ''),
+ "title": row.get(title_column, ""),
 "url": row[url_column],
- "description": row.get(desc_column, ''),
- }
+ "description": row.get(desc_column, ""),
+ },
 )
- for row in df.to_dict(orient='records')
+ for row in df.to_dict(orient="records")
 ]
 # split the documents
 from langchain_text_splitters import RecursiveCharacterTextSplitter
+
 text_splitter = RecursiveCharacterTextSplitter(
 chunk_size=chunk_size,
 chunk_overlap=chunk_overlap,
@@ -238,7 +268,7 @@
 " ",
 "\u200B", # Zero-width space
 "",
- ]
+ ],
 )
 split_documents = text_splitter.split_documents(documents)
@@ -251,7 +281,7 @@
 documents=split_documents[start_idx:end_idx],
 batch_size=batch_size,
 )
- 
+
 # close the qdrant client
 qdrant.client.close()
@@ -275,19 +305,19 @@ def limit_word_count_preserve_newline(input_string, max_word_count):
 """
 word_count = 0
- limited_string = ''
+ limited_string = ""
- for word in input_string.split('\n'):
+ for word in input_string.split("\n"):
 line_words = word.split()
 for lw in line_words:
 if word_count < max_word_count:
- limited_string += lw + ' '
+ limited_string += lw + " "
 word_count += 1
 else:
 break
 if word_count >= max_word_count:
 break
- limited_string = limited_string.strip() + '\n'
+ limited_string = limited_string.strip() + "\n"
 return limited_string.strip()
@@ -305,7 +335,7 @@ def remove_citations(s):
 str: The string with all citation patterns removed.
 """
- return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s)
+ return re.sub(r"\[\d+(?:,\s*\d+)*\]", "", s)
 @staticmethod
 def parse_citation_indices(s):
@@ -318,7 +348,7 @@ def parse_citation_indices(s):
 Returns:
 List[int]: A list of citation indexes extracted from the content, in the order they appear.
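The two citation helpers above behave as follows on a small example (expected output noted in comments):

    from knowledge_storm.utils import ArticleTextProcessing

    text = "STORM writes grounded articles [1, 2]. Citations aid verification [2]."
    # Strips "[1, 2]" and "[2]" (single and comma-separated citation groups).
    print(ArticleTextProcessing.remove_citations(text))
    # [1, 2, 2]: indexes in order of appearance; duplicates are not removed.
    print(ArticleTextProcessing.parse_citation_indices("See [1][2] and [2]."))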
""" - matches = re.findall(r'\[\d+\]', s) + matches = re.findall(r"\[\d+\]", s) return [int(index[1:-1]) for index in matches] @staticmethod @@ -339,19 +369,21 @@ def remove_uncompleted_sentences_with_citations(text): # Convert citations like [1, 2, 3] to [1][2][3]. def replace_with_individual_brackets(match): - numbers = match.group(1).split(', ') - return ' '.join(f'[{n}]' for n in numbers) + numbers = match.group(1).split(", ") + return " ".join(f"[{n}]" for n in numbers) # Deduplicate and sort individual groups of citations. def deduplicate_group(match): citations = match.group(0) - unique_citations = list(set(re.findall(r'\[\d+\]', citations))) - sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]'))) + unique_citations = list(set(re.findall(r"\[\d+\]", citations))) + sorted_citations = sorted( + unique_citations, key=lambda x: int(x.strip("[]")) + ) # Return the sorted unique citations as a string - return ''.join(sorted_citations) + return "".join(sorted_citations) - text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text) - text = re.sub(r'(\[\d+\])+', deduplicate_group, text) + text = re.sub(r"\[([0-9, ]+)\]", replace_with_individual_brackets, text) + text = re.sub(r"(\[\d+\])+", deduplicate_group, text) # Deprecated: Remove sentence without proper ending punctuation and citations. # Split the text into sentences (including citations). @@ -372,29 +404,38 @@ def deduplicate_group(match): # combined_sentences += ' '.join(trailing_citations) # Regex pattern to match sentence endings, including optional citation markers. - eos_pattern = r'([.!?])\s*(\[\d+\])?\s*' + eos_pattern = r"([.!?])\s*(\[\d+\])?\s*" matches = list(re.finditer(eos_pattern, text)) if matches: last_match = matches[-1] - text = text[:last_match.end()].strip() + text = text[: last_match.end()].strip() return text @staticmethod def clean_up_citation(conv): for turn in conv.dlg_history: - turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')] - turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')] - turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip() + turn.agent_utterance = turn.agent_utterance[ + : turn.agent_utterance.find("References:") + ] + turn.agent_utterance = turn.agent_utterance[ + : turn.agent_utterance.find("Sources:") + ] + turn.agent_utterance = turn.agent_utterance.replace("Answer:", "").strip() try: - max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)]) + max_ref_num = max( + [int(x) for x in re.findall(r"\[(\d+)\]", turn.agent_utterance)] + ) except Exception as e: max_ref_num = 0 if max_ref_num > len(turn.search_results): for i in range(len(turn.search_results), max_ref_num + 1): - turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '') - turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( - turn.agent_utterance) + turn.agent_utterance = turn.agent_utterance.replace(f"[{i}]", "") + turn.agent_utterance = ( + ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + turn.agent_utterance + ) + ) return conv @@ -403,36 +444,46 @@ def clean_up_outline(outline, topic=""): output_lines = [] current_level = 0 # To track the current section level - for line in outline.split('\n'): + for line in outline.split("\n"): stripped_line = line.strip() if topic != "" and f"# {topic.lower()}" in stripped_line.lower(): output_lines = [] # Check if the line is a section header - if 
stripped_line.startswith('#'): - current_level = stripped_line.count('#') + if stripped_line.startswith("#"): + current_level = stripped_line.count("#") output_lines.append(stripped_line) # Check if the line is a bullet point - elif stripped_line.startswith('-'): - subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip() + elif stripped_line.startswith("-"): + subsection_header = ( + "#" * (current_level + 1) + " " + stripped_line[1:].strip() + ) output_lines.append(subsection_header) - outline = '\n'.join(output_lines) + outline = "\n".join(output_lines) # Remove references. - outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See also.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See Also.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Notes.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? References.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub( + r"#[#]? External links.*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub( + r"#[#]? External Links.*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub( + r"#[#]? Further reading*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub( + r"#[#]? Further Reading*?(?=##|$)", "", outline, flags=re.DOTALL + ) + outline = re.sub(r"#[#]? Summary.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", "", outline, flags=re.DOTALL) return outline @@ -443,34 +494,40 @@ def clean_up_section(text): 2. Deduplicate individual groups of citations. 3. 
Remove unnecessary summary.""" - paragraphs = text.split('\n') + paragraphs = text.split("\n") output_paragraphs = [] summary_sec_flag = False for p in paragraphs: p = p.strip() if len(p) == 0: continue - if not p.startswith('#'): + if not p.startswith("#"): p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p) if summary_sec_flag: - if p.startswith('#'): + if p.startswith("#"): summary_sec_flag = False else: continue - if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'): + if ( + p.startswith("Overall") + or p.startswith("In summary") + or p.startswith("In conclusion") + ): continue - if "# Summary" in p or '# Conclusion' in p: + if "# Summary" in p or "# Conclusion" in p: summary_sec_flag = True continue output_paragraphs.append(p) - return '\n\n'.join(output_paragraphs) # Join with '\n\n' for markdown format. + return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format. @staticmethod def update_citation_index(s, citation_map): """Update citation index in the string based on the citation map.""" for original_citation in citation_map: - s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__") + s = s.replace( + f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__" + ) for original_citation, unify_citation in citation_map.items(): s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]") @@ -497,34 +554,34 @@ def parse_article_into_dict(input_string): A dictionary representing contains the section title as the key, and another dictionary as the value, which includes the 'content' and 'subsections' keys as described above. """ - lines = input_string.split('\n') + lines = input_string.split("\n") lines = [line for line in lines if line.strip()] - root = {'content': '', 'subsections': {}} + root = {"content": "", "subsections": {}} current_path = [(root, -1)] # (current_dict, level) for line in lines: - if line.startswith('#'): - level = line.count('#') - title = line.strip('# ').strip() - new_section = {'content': '', 'subsections': {}} + if line.startswith("#"): + level = line.count("#") + title = line.strip("# ").strip() + new_section = {"content": "", "subsections": {}} # Pop from stack until find the parent level while current_path and current_path[-1][1] >= level: current_path.pop() # Append new section to the nearest upper level's subsections - current_path[-1][0]['subsections'][title] = new_section + current_path[-1][0]["subsections"][title] = new_section current_path.append((new_section, level)) else: - current_path[-1][0]['content'] += line + '\n' + current_path[-1][0]["content"] += line + "\n" - return root['subsections'] + return root["subsections"] class FileIOHelper: @staticmethod def dump_json(obj, file_name, encoding="utf-8"): - with open(file_name, 'w', encoding=encoding) as fw: + with open(file_name, "w", encoding=encoding) as fw: json.dump(obj, fw, default=FileIOHelper.handle_non_serializable) @staticmethod @@ -533,27 +590,27 @@ def handle_non_serializable(obj): @staticmethod def load_json(file_name, encoding="utf-8"): - with open(file_name, 'r', encoding=encoding) as fr: + with open(file_name, "r", encoding=encoding) as fr: return json.load(fr) @staticmethod def write_str(s, path): - with open(path, 'w') as f: + with open(path, "w") as f: f.write(s) @staticmethod def load_str(path): - with open(path, 'r') as f: - return '\n'.join(f.readlines()) + with open(path, "r") as f: + return "\n".join(f.readlines()) @staticmethod def dump_pickle(obj, path): 
- with open(path, 'wb') as f: + with open(path, "wb") as f: pickle.dump(obj, f) @staticmethod def load_pickle(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return pickle.load(f) @@ -563,7 +620,12 @@ class WebPageHelper: Acknowledgement: Part of the code is adapted from https://github.com/stanford-oval/WikiChat project. """ - def __init__(self, min_char_count: int = 150, snippet_chunk_size: int = 1000, max_thread_num: int = 10): + def __init__( + self, + min_char_count: int = 150, + snippet_chunk_size: int = 1000, + max_thread_num: int = 10, + ): """ Args: min_char_count: Minimum character count for the article to be considered valid. @@ -604,7 +666,9 @@ def download_webpage(self, url: str): return None def urls_to_articles(self, urls: List[str]) -> Dict: - with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_thread_num + ) as executor: htmls = list(executor.map(self.download_webpage, urls)) articles = {}
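End to end, `WebPageHelper` is driven roughly as in the sketch below. The URL is a placeholder, and the per-page payload shape (a dict with a "text" key) is an assumption, since the tail of `urls_to_articles` falls outside this hunk:

    from knowledge_storm.utils import WebPageHelper

    helper = WebPageHelper(min_char_count=150, snippet_chunk_size=1000, max_thread_num=4)
    # Pages are fetched concurrently in a thread pool; pages whose extracted
    # text is shorter than min_char_count are dropped.
    articles = helper.urls_to_articles(["https://example.org/article"])  # placeholder URL
    for url, page in articles.items():
        print(url, len(page.get("text", "")))  # "text" key assumed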