Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use black as the python code formatter. #159

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
- id: black
name: Format Python code with black
entry: black
args: ["knowledge_storm/"]
language: python
pass_filenames: true
7 changes: 6 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,9 @@ Following the suggested format can lead to a faster review process.

**Code Format:**

We adopt [PEP8 rules](https://peps.python.org/pep-0008/) for arranging and formatting Python code. Please use a code formatter tool in your IDE to reformat the code before submitting the PR.
We adopt [`black`](https://github.com/psf/black) for arranging and formatting Python code. To streamline the contribution process, we set up a [pre-commit hook](https://pre-commit.com/) to format the code under `knowledge_storm/` before committing. To install the pre-commit hook, run:
```bash
pip install pre-commit
pre-commit install
```
The hook will automatically format the code before each commit.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
- [2024/05] We add Bing Search support in [rm.py](knowledge_storm/rm.py). Test STORM with `GPT-4o` - we now configure the article generation part in our demo using `GPT-4o` model.
- [2024/04] We release refactored version of STORM codebase! We define [interface](knowledge_storm/interface.py) for STORM pipeline and reimplement STORM-wiki (check out [`src/storm_wiki`](knowledge_storm/storm_wiki)) to demonstrate how to instantiate the pipeline. We provide API to support customization of different language models and retrieval/search integration.

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

## Overview [(Try STORM now!)](https://storm.genie.stanford.edu/)

<p align="center">
Expand Down
2 changes: 1 addition & 1 deletion knowledge_storm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .storm_wiki.engine import (
STORMWikiLMConfigs,
STORMWikiRunnerArguments,
STORMWikiRunner
STORMWikiRunner,
)

__version__ = "0.2.5"
69 changes: 47 additions & 22 deletions knowledge_storm/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from collections import OrderedDict
from typing import Dict, List, Optional, Union

logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s')
logging.basicConfig(
level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
)
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -70,7 +72,9 @@ class Article(ABC):
def __init__(self, topic_name):
self.root = ArticleSectionNode(topic_name)

def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]:
def find_section(
self, node: ArticleSectionNode, name: str
) -> Optional[ArticleSectionNode]:
"""
Return the node of the section given the section name.

Expand Down Expand Up @@ -152,7 +156,9 @@ def prune_empty_nodes(self, node=None):
if node is None:
node = self.root

node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)]
node.children[:] = [
child for child in node.children if self.prune_empty_nodes(child)
]

if (node.content is None or node.content == "") and not node.children:
return None
Expand All @@ -178,7 +184,9 @@ def update_search_top_k(self, k):
def collect_and_reset_rm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
if "_rm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

name_to_usage = {}
Expand Down Expand Up @@ -240,7 +248,9 @@ class OutlineGenerationModule(ABC):
"""

@abstractmethod
def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article:
def generate_outline(
self, topic: str, information_table: InformationTable, **kwargs
) -> Article:
"""
Generate outline for the article. Required arguments include:
topic: the topic of interest
Expand All @@ -263,11 +273,13 @@ class ArticleGenerationModule(ABC):
"""

@abstractmethod
def generate_article(self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs) -> Article:
def generate_article(
self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs,
) -> Article:
"""
Generate article. Required arguments include:
topic: the topic of interest
Expand Down Expand Up @@ -312,22 +324,23 @@ def wrapper(self, *args, **kwargs):
class LMConfigs(ABC):
"""Abstract base class for language model configurations of the knowledge curation engine.

The language model used for each part should be declared with a suffix '_lm' in the attribute name."""
The language model used for each part should be declared with a suffix '_lm' in the attribute name.
"""

def __init__(self):
pass

def init_check(self):
for attr_name in self.__dict__:
if '_lm' in attr_name and getattr(self, attr_name) is None:
if "_lm" in attr_name and getattr(self, attr_name) is None:
logging.warning(
f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()"
)

def collect_and_reset_lm_history(self):
history = []
for attr_name in self.__dict__:
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'):
if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"):
history.extend(getattr(self, attr_name).history)
getattr(self, attr_name).history = []

Expand All @@ -336,7 +349,9 @@ def collect_and_reset_lm_history(self):
def collect_and_reset_lm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
if "_lm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

model_name_to_usage = {}
Expand All @@ -345,17 +360,22 @@ def collect_and_reset_lm_usage(self):
if model_name not in model_name_to_usage:
model_name_to_usage[model_name] = tokens
else:
model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens']
model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens']
model_name_to_usage[model_name]["prompt_tokens"] += tokens[
"prompt_tokens"
]
model_name_to_usage[model_name]["completion_tokens"] += tokens[
"completion_tokens"
]

return model_name_to_usage

def log(self):

return OrderedDict(
{
attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if
'_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs')
attr_name: getattr(self, attr_name).kwargs
for attr_name in self.__dict__
if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs")
}
)

Expand All @@ -379,16 +399,21 @@ def wrapper(*args, **kwargs):
self.time[func.__name__] = execution_time
logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds")
self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()
if hasattr(self, 'retriever'):
self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage()
if hasattr(self, "retriever"):
self.rm_cost[func.__name__] = (
self.retriever.collect_and_reset_rm_usage()
)
return result

return wrapper

def apply_decorators(self):
"""Apply decorators to methods that need them."""
methods_to_decorate = [method_name for method_name in dir(self)
if callable(getattr(self, method_name)) and method_name.startswith('run_')]
methods_to_decorate = [
method_name
for method_name in dir(self)
if callable(getattr(self, method_name)) and method_name.startswith("run_")
]
for method_name in methods_to_decorate:
original_method = getattr(self, method_name)
decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)
Expand Down
Loading
Loading