-
Notifications
You must be signed in to change notification settings - Fork 309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP][Feature] DPO #434
base: main
Are you sure you want to change the base?
[WIP][Feature] DPO #434
Conversation
xtuner/model/dpo.py
Outdated
self.use_varlen_attn = use_varlen_attn | ||
|
||
# TODO: Add ref model and ref model config | ||
self.ref_llm = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ref_llm, 也支持 api model
更新了 dpo 的实现,使用 sft 的数据,可以跑通流程,但是存在两个问题:
@xiaohangguo @pppppM 佬们,看下这两个问题是为啥呀 |
ref_model 要不直接用 llm 的 config 重新 build ? loss 为 nan 可能要 @xiaohangguo 帮忙看下公式细节 |
好,今晚我切到这个分支复现一下,debug看看 |
可以 我试试改成 用 llm 的 config 重新 build |
写了个Mock 数据pytest来验证算法,目前测试结果,loss计算应该是没有问题。
下一步需要适配Class DPOdataset ,一条batch中格式保持 |
把item_fn 搞了一下,但感觉还是有问题,单个conversation,应该是可以的,不知道能否和原来的encode_fn 结合,对于整个数据集处理好,正常走packer。 |
|
太强了! |
@amulil 请问现在有DPO训练的模型指标对比吗?我想参考这个实现RLHF-V |
@KooSung 目前暂时没有,后面会参考 https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/README.md 提到的 zephyr-7b-dpo-qlora 模型来看指标对比。 |
@pppppM 佬,按你说的,初步想法是在 dataset 目录下实现
DPODataset
,在 model 目录下实现DPO
,其他 hook 暂时和 sft 一致的,不用修改,但是有一个疑问,DPO 里有 model 和 ref_model 两个 model,deepspeed 相关的部分用修改嘛?