Merge pull request #193 from imoneoi/moe-and-gemma-support
MoE and gemma support
imoneoi committed Mar 9, 2024
2 parents 7fbad42 + 3783903 commit 4f4e426
Showing 9 changed files with 553 additions and 47 deletions.
51 changes: 51 additions & 0 deletions ochat/config/__init__.py
@@ -16,11 +16,31 @@
}


_GEMMA_IT_PREFIXES = {
"user": "user",
"assistant": "model"
}


def _v3_2_role_prefix(from_role, condition):
return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()


MODEL_CONFIG_MAP = {
# OpenChat V3.6 (MoE)
"openchat_3.6": ModelConfig(
# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
model_create_for_training=lambda: None, # NOTE(one): MoE trainer decoupled from the codebase

# Conversation Template
conversation_template=partial(ConversationTemplate,
role_prefix=_v3_2_role_prefix,
eot="</s>",
inference_condition="GPT4 Correct")
),

# OpenChat V3.2
"openchat_v3.2": ModelConfig(
# Model
@@ -54,6 +74,23 @@ def _v3_2_role_prefix(from_role, condition):
inference_condition="GPT4 Correct")
),

"openchat_v3.2_gemma_new": ModelConfig(
serving_aliases=("openchat_3.5_gemma_new", ),

# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),

# Conversation Template
conversation_template=partial(ConversationTemplate,
role_prefix=_v3_2_role_prefix,
eot="<end_of_turn>",
inference_condition="GPT4 Correct")
),

### Other models
"chatml_mistral": ModelConfig(
# Model
@@ -83,4 +120,18 @@ def _v3_2_role_prefix(from_role, condition):
eot="</s>",
inference_condition="")
),
"gemma_it": ModelConfig(
# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),

# Conversation Template
conversation_template=partial(ConversationTemplate,
role_prefix=lambda from_role, condition: f"<start_of_turn>{_GEMMA_IT_PREFIXES[from_role]}\n",
eot="<end_of_turn>",
inference_condition="")
),
}
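
For orientation, here is a minimal sketch (not part of the diff) of what the two role-prefix styles above produce. _V3_2_PREFIXES is defined above the visible hunk, so its values here are an assumption:

# Assumed values for _V3_2_PREFIXES; the real dict lives above the visible hunk.
_V3_2_PREFIXES = {"user": "User:", "assistant": "Assistant:"}
_GEMMA_IT_PREFIXES = {"user": "user", "assistant": "model"}

def v3_2_role_prefix(from_role, condition):
    # Same logic as _v3_2_role_prefix in the diff
    return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()

def gemma_it_role_prefix(from_role, condition):
    # Same logic as the gemma_it lambda in the diff; the condition is unused
    return f"<start_of_turn>{_GEMMA_IT_PREFIXES[from_role]}\n"

print(v3_2_role_prefix("assistant", "GPT4 Correct"))  # -> "GPT4 Correct Assistant:"
print(gemma_it_role_prefix("user", ""))               # -> "<start_of_turn>user\n"

The new openchat_v3.2_gemma_new and gemma_it entries end turns with "<end_of_turn>", while openchat_3.6 keeps the v3.2 prefix style with "</s>" as its end-of-turn token.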
74 changes: 48 additions & 26 deletions ochat/data/generate_dataset.py
@@ -4,9 +4,9 @@
Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.jsonl --tokenizer-name HF_REPO_NAME --out-dir .
"""

from typing import Optional
import argparse
import os
import gc
import random

import ray
@@ -77,6 +77,9 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc
print ("Tokenizing ...")
tokens_list, weights_list = conv_template.tokenize_conversations(batch, inference=False, seq_level_weight=per_sequence_loss)

del batch
gc.collect()

# Generate data
print ("Generating ...")
max_context = model_config.model_max_context
@@ -92,12 +95,20 @@ def convert_conversation_batch(model_type: str, model_path: str, batch: list, sc
# Add to results
add_single_conv(outputs, tokens, weights)

print ("Chunk finish")
del tokens_list, weights_list
gc.collect()

return pyarrow.Table.from_pydict(outputs, schema=schema)
print ("To table ...")
table = pyarrow.Table.from_pydict(outputs, schema=schema)

del outputs
gc.collect()

print ("Chunk finish")
return table

def generate_split(model_type: str, model_path: str, conversations: list, split_name: str, out_prefix: str, per_sequence_loss: bool):

def generate_epoch(seed: int, model_type: str, model_path: str, in_filename: str, out_filename: str, per_sequence_loss: bool):
# schema
metadata = {
"model_type": model_type
@@ -115,53 +126,64 @@ def generate_split(model_type: str, model_path: str, conversations: list, split_

schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)})

# launch remote workers
if not ray.is_initialized():
ray.init(ignore_reinit_error=True, num_cpus=os.cpu_count())
# Load data
with open(in_filename, "rb") as f:
batches = f.readlines()

random.seed(seed) # Randomized load balancing
random.shuffle(batches)

batches = _split(batches, int(ray.available_resources()["CPU"]))

# launch remote workers
handles = [convert_conversation_batch.remote(
model_type=model_type, # type: ignore
model_path=model_path,
batch=batch,
schema=schema,
per_sequence_loss=per_sequence_loss
) for batch in _split(conversations, int(ray.available_resources()["CPU"]))]
) for batch in batches]

# write
parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), f"{out_prefix}.{split_name}.parquet")
parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), out_filename)


def generate_dataset(model_type, model_path, in_files, out_prefix, per_sequence_loss, seed, eval_ratio):
# Load conversations
conversations = []
for filename in in_files:
with open(filename, "rt") as f:
conversations.extend(f.readlines())
def generate_dataset(model_type, model_path, in_prefix, out_prefix, per_sequence_loss, seed):
# Initialize Ray
if not ray.is_initialized():
ray.init(ignore_reinit_error=True, num_cpus=os.cpu_count())

# Train-test split
random.seed(seed)
random.shuffle(conversations)
eval_num = int(eval_ratio * len(conversations))
# Load epochs and tokenize
epoch = 0
while True:
in_filename = f"{in_prefix}.{epoch}.jsonl"
if not os.path.exists(in_filename):
break

train_conversations = conversations[eval_num:]
eval_conversations = conversations[:eval_num]
out_filename = f"{out_prefix}.{epoch}.parquet"
generate_epoch(
seed=seed + epoch,
model_type=model_type,
model_path=model_path,
in_filename=in_filename,
out_filename=out_filename,
per_sequence_loss=per_sequence_loss
)
gc.collect()

generate_split(model_type, model_path, train_conversations, "train", out_prefix, per_sequence_loss)
if eval_num > 0:
generate_split(model_type, model_path, eval_conversations, "eval", out_prefix, per_sequence_loss)
epoch += 1


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-type", type=str, required=True)
parser.add_argument("--model-path", type=str, required=True)

parser.add_argument("--in-files", type=str, nargs="+", required=True)
parser.add_argument("--in-prefix", type=str, required=True)
parser.add_argument("--out-prefix", type=str, required=True)

parser.add_argument("--per-sequence-loss", action="store_true")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--eval-ratio", type=float, default=0.005)
args = parser.parse_args()

generate_dataset(**vars(args))
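
The script now consumes pre-split per-epoch files instead of a single conversation list with a train/eval split: given --in-prefix data and --out-prefix out, it reads data.0.jsonl, data.1.jsonl, ... and writes out.0.parquet, out.1.parquet, ... until an epoch file is missing. A usage sketch under assumed paths and a placeholder checkpoint:

from ochat.data.generate_dataset import generate_dataset

# Hypothetical paths; reads data.<epoch>.jsonl and writes out.<epoch>.parquet
generate_dataset(
    model_type="openchat_v3.2_gemma_new",
    model_path="google/gemma-7b",   # placeholder tokenizer/model path
    in_prefix="data",
    out_prefix="out",
    per_sequence_loss=False,
    seed=42,
)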
11 changes: 5 additions & 6 deletions ochat/evaluation/match_answer.py
@@ -109,7 +109,7 @@ def fs_cothub_bbh_match_answer(task_data, response):
return False, ans
else:
# Free form, direct return
if ans[-1] == '.':
if len(ans) and ans[-1] == '.':
ans = ans[:-1]

return True, ans
@@ -155,12 +155,11 @@ def _function_exists(code, func_name):
return False

def _try_match(content, prefix, entrypoint):
for block in content.split("```"):
# Sanitize block
block = block.strip()
if block.startswith("python"):
block = block[len("python"):]
# All markdown code blocks, as well as raw
code_blocks = [m[1] for m in re.findall(r"(\`{3}.*?\n+)([\s\S]*?)(\n+\`{3})", content)] \
+ [content]

for block in code_blocks:
# Check syntax
try:
code_completion = prefix + block
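
The rewritten extraction collects the body of every fenced code block (regardless of language tag) and appends the raw response as a final fallback candidate. A small illustration with made-up content:

import re

content = "Here is my solution:\n```python\ndef add(a, b):\n    return a + b\n```\nHope this helps."
code_blocks = [m[1] for m in re.findall(r"(\`{3}.*?\n+)([\s\S]*?)(\n+\`{3})", content)] + [content]
# code_blocks[0] -> "def add(a, b):\n    return a + b"   (fence and language tag stripped)
# code_blocks[1] -> the raw response, tried last in case the model answered without fences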
23 changes: 16 additions & 7 deletions ochat/evaluation/run_eval.py
@@ -17,6 +17,12 @@
from ochat.config import MODEL_CONFIG_MAP


def _strip_first_space(s: str):
if len(s) and s[0] == " ":
return s[1:]
return s


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(20), retry=retry_if_exception_type((RateLimitError, ServiceUnavailableError, )))
async def _chat_completion_with_backoff(**kwargs):
return await openai.ChatCompletion.acreate(**kwargs)
@@ -122,7 +128,8 @@ def get_model_answers(
questions: list,
condition: str,
system_msg: str,
model_type: str
model_type: str,
tensor_parallel_size: int
):
# Load model config
if model_type is None:
@@ -136,9 +143,10 @@
# Init vLLM engine
engine = LLM(model,
max_num_batched_tokens=model_config.model_max_context,
max_model_len=model_config.model_max_context)
max_model_len=model_config.model_max_context,
tensor_parallel_size=tensor_parallel_size)
sampling_params = SamplingParams(temperature=0,
max_tokens=model_config.model_max_context,
max_tokens=None,
stop_token_ids=conv_template.eot_tokens_, # Override stop tokens
ignore_eos=True)

@@ -149,8 +157,7 @@
# calculate & fill in responses
responses = engine.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
for idx, resp in zip(prompt_indices, responses):
questions[idx]["response"] = resp.outputs[0].text

questions[idx]["response"] = _strip_first_space(resp.outputs[0].text)

return questions

@@ -167,7 +174,8 @@ async def run_eval(
continue_from: Optional[str],
output_file: str,

parallel: int
parallel: int,
tensor_parallel_size: int
):
print (f"Evaluating ({model_type})...\n\nCondition: {condition}\nSystem Prompt: {system_msg}\n")

@@ -201,7 +209,7 @@ async def run_eval(
if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
questions = await get_openai_answers(model, questions, parallel)
else:
questions = get_model_answers(model, questions, condition, system_msg, model_type)
questions = get_model_answers(model, questions, condition, system_msg, model_type, tensor_parallel_size)

# Calculate accuracy
for q in questions:
@@ -235,6 +243,7 @@ async def main():
parser.add_argument("--continue_from", type=str, default=None)
parser.add_argument("--output_file", type=str, default=None)
parser.add_argument("--parallel", type=int, default=16)
parser.add_argument("--tensor-parallel-size", type=int, default=1)

args = parser.parse_args()

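
Two behavioral changes in this file are worth noting: the vLLM engine can now be sharded across GPUs via the new --tensor-parallel-size flag, and model responses are passed through _strip_first_space, presumably because completions after a prompt ending in a role prefix such as "GPT4 Correct Assistant:" tend to begin with a space. A quick illustration of the helper (not from the diff):

assert _strip_first_space(" The answer is 42.") == "The answer is 42."  # drops exactly one leading space
assert _strip_first_space("No leading space") == "No leading space"     # unchanged otherwise
assert _strip_first_space("") == ""                                     # safe on empty output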
1 change: 1 addition & 0 deletions ochat/models/__init__.py
@@ -1,2 +1,3 @@
from ochat.models.unpadded_llama import LlamaForCausalLM
from ochat.models.unpadded_mistral import MistralForCausalLM
from ochat.models.unpadded_gemma import GemmaForCausalLM
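
This re-export is what the new Gemma entries in MODEL_CONFIG_MAP use for model_create_for_training; an equivalent direct call would look roughly like this (checkpoint name is a placeholder):

import torch
from ochat.models import GemmaForCausalLM

model = GemmaForCausalLM.from_pretrained(
    "google/gemma-7b",        # placeholder; any compatible Gemma checkpoint
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)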
(Diffs for the remaining 4 changed files are not shown.)
