How to decode predicted texts from CausalLMOutputWithPast in the training code? #46

Open
erjiaxiao opened this issue Sep 10, 2024 · 1 comment

@erjiaxiao

Hello @siddk, I loaded the "prism-dinosiglip-224px+7b" weights you uploaded on Hugging Face and tried to decode the predicted text from the CausalLMOutputWithPast returned in the training code, as follows:

                for train_idx, batch in enumerate(dataloader):
                    with torch.autocast(
                        "cuda",
                        dtype=self.mixed_precision_dtype,
                        enabled=self.enable_mixed_precision_training,
                    ):
                        output: CausalLMOutputWithPast = self.vlm(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            pixel_values=batch["pixel_values"],
                            labels=batch["labels"],
                            multimodal_indices=batch["multimodal_indices"],
                        )
                        loss = output.loss

                    # Greedy-decode the first sequence in the batch.
                    predicted_ids = torch.argmax(output["logits"], dim=-1)[0]
                    predicted_str = self.vlm.llm_backbone.tokenizer.decode(
                        predicted_ids, skip_special_tokens=True
                    ).strip()

I decode the logits of the first sequence in the batch, but I get a garbled response like "Unterscheidung in0 atrys d2 grassy3 dery02ro imagepet:4 yellow-,y grassiaans grass in p1 in,1es d and grasss Ves blackbl0 ,0,, d me we of black color,,ery ofch ofery garden plant grass flowers0 v,,,, v0eryies bird tur plantsclose grassy v". That seems wrong.
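For what it's worth, part of the garbling may simply be an alignment issue: for a causal LM, the logits at position t score the token at position t + 1, so the greedy argmax sequence has to be shifted by one before it lines up with the input (and positions occupied by image patches or -100 label padding won't decode to anything meaningful anyway). A minimal pure-Python sketch of the shift, with a toy three-token vocabulary (all names here are illustrative, not from the repo):

```python
# Toy illustration of the shift-by-one alignment in causal-LM decoding:
# logits[t] predicts the token at position t + 1, not position t.

def greedy_decode_aligned(input_ids, logits):
    """Return (target_token, predicted_token) pairs.

    The prediction at position t is compared against input_ids[t + 1];
    the final position has no target and is dropped.
    """
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return list(zip(input_ids[1:], predictions[:-1]))

# Toy example: vocabulary {0, 1, 2}, input sequence [2, 0, 1].
input_ids = [2, 0, 1]
logits = [
    [0.1, 0.2, 0.7],  # position 0 predicts token 2; target is input_ids[1] = 0
    [0.1, 0.8, 0.1],  # position 1 predicts token 1; target is input_ids[2] = 1
    [0.9, 0.0, 0.1],  # position 2 predicts token 0; no target, dropped
]
pairs = greedy_decode_aligned(input_ids, logits)  # [(0, 2), (1, 1)]
```

So decoding the unshifted argmax interleaves each prediction with the wrong position, which by itself produces text that looks scrambled even when the per-position predictions are reasonable.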

@erjiaxiao (Author)

I also load the full checkpoint in the training code, as follows:

# load complete weights
model_state_dict = torch.load("models/prism-dinosiglip-224px+7b/checkpoints/latest-checkpoint.pt", map_location="cpu")["model"]
vlm.projector.load_state_dict(model_state_dict["projector"])
vlm.llm_backbone.load_state_dict(model_state_dict["llm_backbone"])

I wonder whether I missed any settings needed for training when loading the full checkpoint in the training code.
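As a sanity check before calling load_state_dict, it may help to verify that the checkpoint actually contains both sub-state-dicts. This is a hedged sketch: the key names below just mirror the dict keys used in my snippet above, not a verified checkpoint schema, and missing_checkpoint_keys is a hypothetical helper.

```python
# Verify that the expected top-level sub-dicts exist in the loaded
# checkpoint before handing them to load_state_dict.

def missing_checkpoint_keys(model_state_dict, expected=("projector", "llm_backbone")):
    """Return the expected top-level keys absent from the checkpoint dict."""
    return [key for key in expected if key not in model_state_dict]

# Toy checkpoint with only the projector weights present:
ckpt = {"projector": {"fc.weight": None}}
absent = missing_checkpoint_keys(ckpt)  # ["llm_backbone"]
```

With the real checkpoint one would run this on model_state_dict right after torch.load, and also keep load_state_dict's default strict=True so that any missing or unexpected tensor names raise immediately instead of silently leaving weights uninitialized.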
