Skip to content

Commit

Permalink
[Bug] Update api with generation_kwargs (#614)
Browse files (browse the repository at this point in the history)
* update api

* update generation_kwargs impl

---------

Co-authored-by: Leymore <[email protected]>
  • Loading branch information
tonysy and Leymore authored Nov 22, 2023
1 parent eb56fd6 commit 721a45c
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 17 deletions.
6 changes: 5 additions & 1 deletion opencompass/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class BaseModel:
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
generation_kwargs (Dict, optional): The generation kwargs for the
model. Defaults to dict().
"""

is_api: bool = False
Expand All @@ -27,7 +29,8 @@ def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_only: bool = False,
meta_template: Optional[Dict] = None):
meta_template: Optional[Dict] = None,
generation_kwargs: Optional[Dict] = dict()):
self.path = path
self.max_seq_len = max_seq_len
self.tokenizer_only = tokenizer_only
Expand All @@ -36,6 +39,7 @@ def __init__(self,
self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
self.generation_kwargs = generation_kwargs

@abstractmethod
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
Expand Down
6 changes: 5 additions & 1 deletion opencompass/models/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class BaseAPIModel(BaseModel):
meta_template (Dict, optional): The model's meta prompt
template if needed, in case the requirement of injecting or
wrapping of any meta instructions.
generation_kwargs (Dict, optional): The generation kwargs for the
model. Defaults to dict().
"""

is_api: bool = True
Expand All @@ -37,7 +39,8 @@ def __init__(self,
query_per_second: int = 1,
retry: int = 2,
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None):
meta_template: Optional[Dict] = None,
generation_kwargs: Dict = dict()):
self.path = path
self.max_seq_len = max_seq_len
self.meta_template = meta_template
Expand All @@ -46,6 +49,7 @@ def __init__(self,
self.token_bucket = TokenBucket(query_per_second)
self.template_parser = APITemplateParser(meta_template)
self.logger = get_logger()
self.generation_kwargs = generation_kwargs

@abstractmethod
def generate(self, inputs: List[PromptType],
Expand Down
21 changes: 9 additions & 12 deletions opencompass/models/lightllm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,22 @@ class LightllmAPI(BaseAPIModel):
is_api: bool = True

def __init__(
self,
path: str = 'LightllmAPI',
url: str = 'http://localhost:8080/generate',
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
generation_kwargs: Optional[Dict] = None,
self,
path: str = 'LightllmAPI',
url: str = 'http://localhost:8080/generate',
max_seq_len: int = 2048,
meta_template: Optional[Dict] = None,
retry: int = 2,
generation_kwargs: Optional[Dict] = dict(),
):

super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template,
retry=retry)
retry=retry,
generation_kwargs=generation_kwargs)
self.logger = get_logger()
self.url = url
if generation_kwargs is not None:
self.generation_kwargs = generation_kwargs
else:
self.generation_kwargs = {}
self.do_sample = self.generation_kwargs.get('do_sample', False)
self.ignore_eos = self.generation_kwargs.get('ignore_eos', False)

Expand Down
1 change: 0 additions & 1 deletion opencompass/models/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
tm_model.create_instance() for i in range(concurrency)
]
self.generator_ids = [i + 1 for i in range(concurrency)]
self.generation_kwargs = dict()

def generate(
self,
Expand Down
1 change: 0 additions & 1 deletion opencompass/models/turbomind_tis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __init__(
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
self.tis_addr = tis_addr
self.generation_kwargs = dict()

def generate(
self,
Expand Down
2 changes: 1 addition & 1 deletion opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def inference(self,
entry, max_out_len=self.max_out_len)
generated = results

num_return_sequences = self.model.generation_kwargs.get(
num_return_sequences = self.model.get('generation_kwargs', {}).get(
'num_return_sequences', 1)
# 5-3. Save current output
for prompt, prediction, gold in zip(
Expand Down

0 comments on commit 721a45c

Please sign in to comment.