
Error when using FUDGE: RuntimeError: Error(s) in loading state_dict for Model: #604

Open
Enermy opened this issue Oct 22, 2024 · 0 comments

Enermy commented Oct 22, 2024

While using FUDGE (https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) to generate rhyming poetry with glm4, I ran into the following problem:
[screenshot: RuntimeError: Error(s) in loading state_dict for Model: size mismatch for gpt_embed.weight: copying a param with shape torch.Size([50258, 300]) from checkpoint, the shape in current model is torch.Size([151344, 300]).]
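The mismatch comes from the embedding table inside the FUDGE predictor: the released checkpoints were trained against the GPT-2 tokenizer (50257 tokens plus one added pad token, hence 50258 rows), while a `Model` constructed around the GLM-4 tokenizer allocates 151344 rows, as the error reports. A minimal sketch of the per-parameter shape check that `load_state_dict` performs, with shapes represented as plain tuples (no torch needed; the helper name is hypothetical):

```python
# Why load_state_dict raises: the saved embedding's row count (GPT-2
# vocab + pad token = 50258) differs from the freshly built model's
# row count (GLM-4 vocab = 151344). Tuples stand in for tensor shapes.
ckpt_embed_shape = (50258, 300)    # shape saved in the released FUDGE ckpt
model_embed_shape = (151344, 300)  # shape of the Model built for GLM-4

def shapes_compatible(saved, current):
    """Hypothetical helper mimicking load_state_dict's shape comparison."""
    return saved == current

assert not shapes_compatible(ckpt_embed_shape, model_embed_shape)
print("size mismatch for gpt_embed.weight:",
      ckpt_embed_shape, "vs", model_embed_shape)
```

Because the predictor's embedding is tied to the tokenizer's vocabulary size, a checkpoint trained with one tokenizer cannot be loaded into a model built around another.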
The full code is as follows:

# Imports assumed by this snippet; Model, num_params, predict_couplet and
# PAD_TOKEN come from the FUDGE repository's own modules.
import pickle
import random
from argparse import ArgumentParser

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def main(args):
    with open(args.dataset_info, 'rb') as rf:
        dataset_info = pickle.load(rf)
    MODEL_PATH = '/home/jiangsiyuan/glm-4-9b-chat'
    gpt_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    
    gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
    # gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
    gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN, add_special_tokens=False)[0]
    gpt_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True).to(args.device)
    gpt_model.eval()

    # Load the checkpoint and the iambic model
    checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word))
    iambic_model.load_state_dict(checkpoint['state_dict'])
    iambic_model = iambic_model.to(args.device)
    iambic_model.eval()
    
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.iambic_ckpt, checkpoint['epoch']))
        print('iambic model num params', num_params(iambic_model))

    with open(args.rhyme_info, 'rb') as rf:
        rhyme_info = pickle.load(rf)
    checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group), verbose=args.verbose) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    rhyme_model.load_state_dict(checkpoint['state_dict'])
    rhyme_model = rhyme_model.to(args.device)
    rhyme_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.rhyme_ckpt, checkpoint['epoch']))
        print('rhyme model num params', num_params(rhyme_model))
    
    checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
    model_args = checkpoint['args']
    newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
    newline_model.load_state_dict(checkpoint['state_dict'])
    newline_model = newline_model.to(args.device)
    newline_model.eval()
    if args.verbose:
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.newline_ckpt, checkpoint['epoch']))
        print('newline model num params', num_params(newline_model))
    with open(args.prefix_file, 'r') as rf:
        lines = rf.readlines()
    for line in tqdm(lines, total=len(lines)):
        couplet = predict_couplet(gpt_model, 
                gpt_tokenizer, 
                iambic_model, 
                rhyme_model,
                newline_model,
                [line], 
                dataset_info, 
                rhyme_info,
                args.precondition_topk,
                args.topk, 
                condition_lambda=args.condition_lambda,
                device=args.device)
        assert len(couplet) == 2
        print(couplet[1].strip().replace('\n', ''))
if __name__ == '__main__':
    parser = ArgumentParser()
    # DATA
    parser.add_argument('--iambic_ckpt', type=str, default='ckpt/poetry/iambic_predictor/model.pth.tar')
    parser.add_argument('--rhyme_ckpt', type=str, default='ckpt/poetry/rhyme_predictor/model.pth.tar')
    parser.add_argument('--newline_ckpt', type=str, default='ckpt/poetry/newline_predictor/model.pth.tar')
    parser.add_argument('--dataset_info', type=str, help='saved dataset info', default='ckpt/poetry/rhyme_predictor/dataset_info')
    parser.add_argument('--rhyme_info', type=str, help='saved rhyme info', default='ckpt/poetry/rhyme_predictor/rhyme_info')
    parser.add_argument('--model_string', type=str, default='/home/jiangsiyuan/glm-4-9b-chat')
    parser.add_argument('--prefix_file', type=str, default='poetry_data/couplet_prefixes.txt', help='file of prefix lines for couplets')
    parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
    parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
    parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    parser.add_argument('--debug', action='store_true', default=False)
    args = parser.parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    main(args)
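Before calling `load_state_dict`, one way to see every parameter that disagrees is to diff the shapes of the two state dicts side by side. A hedged sketch, with dicts of name → shape standing in for real state dicts (with torch you would build them as `{k: tuple(v.shape) for k, v in sd.items()}`; the second parameter name below is made up for illustration):

```python
# Compare parameter shapes between a loaded checkpoint and a fresh model.
# Plain dicts of name -> shape stand in for real state_dicts here.
ckpt_shapes = {
    "gpt_embed.weight": (50258, 300),   # from the released FUDGE checkpoint
    "rnn.weight_ih_l0": (1200, 300),    # hypothetical second parameter
}
model_shapes = {
    "gpt_embed.weight": (151344, 300),  # Model built with the GLM-4 tokenizer
    "rnn.weight_ih_l0": (1200, 300),
}

# Collect every parameter whose saved shape differs from the current one.
mismatched = [k for k in ckpt_shapes
              if ckpt_shapes[k] != model_shapes.get(k)]
print("mismatched params:", mismatched)
```

Here only `gpt_embed.weight` would be reported, confirming that the vocabulary-sized embedding is the sole incompatibility between the released checkpoint and a GLM-4-sized model.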

Any help would be much appreciated, thanks!

@Enermy changed the title from "RuntimeError: Error(s) in loading state_dict for Model: size mismatch for gpt_embed.weight: copying a param with shape torch.Size([50258, 300]) from checkpoint, the shape in current model is torch.Size([151344, 300])." to "Error when using FUDGE: RuntimeError: Error(s) in loading state_dict for Model:" on Oct 22, 2024
@Enermy closed this as completed Oct 24, 2024
@Enermy reopened this Oct 24, 2024