Skip to content

Commit

Permalink
Reformat knowledge_storm/** using black.
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoyijia committed Sep 1, 2024
1 parent ded9687 commit 89c8aad
Show file tree
Hide file tree
Showing 13 changed files with 1,368 additions and 827 deletions.
2 changes: 1 addition & 1 deletion knowledge_storm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .storm_wiki.engine import (
STORMWikiLMConfigs,
STORMWikiRunnerArguments,
STORMWikiRunner
STORMWikiRunner,
)

__version__ = "0.2.5"
69 changes: 47 additions & 22 deletions knowledge_storm/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from collections import OrderedDict
from typing import Dict, List, Optional, Union

logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s')
logging.basicConfig(
level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
)
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -70,7 +72,9 @@ class Article(ABC):
def __init__(self, topic_name):
self.root = ArticleSectionNode(topic_name)

def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]:
def find_section(
self, node: ArticleSectionNode, name: str
) -> Optional[ArticleSectionNode]:
"""
Return the node of the section given the section name.
Expand Down Expand Up @@ -152,7 +156,9 @@ def prune_empty_nodes(self, node=None):
if node is None:
node = self.root

node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)]
node.children[:] = [
child for child in node.children if self.prune_empty_nodes(child)
]

if (node.content is None or node.content == "") and not node.children:
return None
Expand All @@ -178,7 +184,9 @@ def update_search_top_k(self, k):
def collect_and_reset_rm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
if "_rm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

name_to_usage = {}
Expand Down Expand Up @@ -240,7 +248,9 @@ class OutlineGenerationModule(ABC):
"""

@abstractmethod
def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article:
def generate_outline(
self, topic: str, information_table: InformationTable, **kwargs
) -> Article:
"""
Generate outline for the article. Required arguments include:
topic: the topic of interest
Expand All @@ -263,11 +273,13 @@ class ArticleGenerationModule(ABC):
"""

@abstractmethod
def generate_article(self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs) -> Article:
def generate_article(
self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs,
) -> Article:
"""
Generate article. Required arguments include:
topic: the topic of interest
Expand Down Expand Up @@ -312,22 +324,23 @@ def wrapper(self, *args, **kwargs):
class LMConfigs(ABC):
"""Abstract base class for language model configurations of the knowledge curation engine.
The language model used for each part should be declared with a suffix '_lm' in the attribute name."""
The language model used for each part should be declared with a suffix '_lm' in the attribute name.
"""

def __init__(self):
pass

def init_check(self):
for attr_name in self.__dict__:
if '_lm' in attr_name and getattr(self, attr_name) is None:
if "_lm" in attr_name and getattr(self, attr_name) is None:
logging.warning(
f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()"
)

def collect_and_reset_lm_history(self):
history = []
for attr_name in self.__dict__:
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'):
if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"):
history.extend(getattr(self, attr_name).history)
getattr(self, attr_name).history = []

Expand All @@ -336,7 +349,9 @@ def collect_and_reset_lm_history(self):
def collect_and_reset_lm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
if "_lm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

model_name_to_usage = {}
Expand All @@ -345,17 +360,22 @@ def collect_and_reset_lm_usage(self):
if model_name not in model_name_to_usage:
model_name_to_usage[model_name] = tokens
else:
model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens']
model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens']
model_name_to_usage[model_name]["prompt_tokens"] += tokens[
"prompt_tokens"
]
model_name_to_usage[model_name]["completion_tokens"] += tokens[
"completion_tokens"
]

return model_name_to_usage

def log(self):

return OrderedDict(
{
attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if
'_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs')
attr_name: getattr(self, attr_name).kwargs
for attr_name in self.__dict__
if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs")
}
)

Expand All @@ -379,16 +399,21 @@ def wrapper(*args, **kwargs):
self.time[func.__name__] = execution_time
logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds")
self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()
if hasattr(self, 'retriever'):
self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage()
if hasattr(self, "retriever"):
self.rm_cost[func.__name__] = (
self.retriever.collect_and_reset_rm_usage()
)
return result

return wrapper

def apply_decorators(self):
"""Apply decorators to methods that need them."""
methods_to_decorate = [method_name for method_name in dir(self)
if callable(getattr(self, method_name)) and method_name.startswith('run_')]
methods_to_decorate = [
method_name
for method_name in dir(self)
if callable(getattr(self, method_name)) and method_name.startswith("run_")
]
for method_name in methods_to_decorate:
original_method = getattr(self, method_name)
decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)
Expand Down
Loading

0 comments on commit 89c8aad

Please sign in to comment.