From 11e5d1fdb52aecf71c78ef6eae7c77e8b85d1537 Mon Sep 17 00:00:00 2001 From: Ziqing Yang Date: Mon, 8 May 2023 18:17:13 +0800 Subject: [PATCH] fix saving to multiple shards --- scripts/merge_llama_with_chinese_lora.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/merge_llama_with_chinese_lora.py b/scripts/merge_llama_with_chinese_lora.py index 0de5c50..c6fc73b 100644 --- a/scripts/merge_llama_with_chinese_lora.py +++ b/scripts/merge_llama_with_chinese_lora.py @@ -134,8 +134,12 @@ def save_shards(model_sd, num_shards: int): splits = v.split(v.size(1)//num_shards,dim=1) elif new_k=='output.weight': print(f"Processing {new_k}") - splits = v.split(v.size(0)//num_shards,dim=0) - + if v.size(0)%num_shards==0: + splits = v.split(v.size(0)//num_shards,dim=0) + else: + size_list = [v.size(0)//num_shards] * num_shards + size_list[-1] += v.size(0)%num_shards + splits = v.split(size_list, dim=0) # 13B: size_list == [24976,24977] elif new_k=='norm.weight': print(f"Processing {new_k}") splits = [v] * num_shards