Skip to content

Commit

Permalink
fix saving to multiple shards
Browse files Browse the repository at this point in the history
  • Loading branch information
airaria committed May 8, 2023
1 parent 08d2ee2 commit 11e5d1f
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions scripts/merge_llama_with_chinese_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,12 @@ def save_shards(model_sd, num_shards: int):
splits = v.split(v.size(1)//num_shards,dim=1)
elif new_k=='output.weight':
print(f"Processing {new_k}")
splits = v.split(v.size(0)//num_shards,dim=0)

if v.size(0)%num_shards==0:
splits = v.split(v.size(0)//num_shards,dim=0)
else:
size_list = [v.size(0)//num_shards] * num_shards
size_list[-1] += v.size(0)%num_shards
splits = v.split(size_list, dim=0) # 13B: size_list == [24976,24977]
elif new_k=='norm.weight':
print(f"Processing {new_k}")
splits = [v] * num_shards
Expand Down

0 comments on commit 11e5d1f

Please sign in to comment.