Support device_map='auto'
czczup committed Feb 4, 2024
1 parent 5aa95de · commit 21baf4c
Showing 5 changed files with 5 additions and 1 deletion.
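The change is the same across all five files: each top-level model class gains a `_no_split_modules` attribute naming the submodule classes that must stay together on a single device, and the existing list in InternVLPreTrainedModel is updated to split at the encoder-layer level rather than the attention level. Accelerate reads this attribute when `device_map='auto'` is requested; the sketch below is an assumption about that internal flow, not code from this commit.

# Sketch (assumed flow): how _no_split_modules feeds automatic placement.
# `device_map='auto'` in from_pretrained reduces roughly to this Accelerate call.
from accelerate import infer_auto_device_map

device_map = infer_auto_device_map(
    model,                                            # an instantiated InternVLChatModel
    no_split_module_classes=model._no_split_modules,  # keep whole encoder/decoder layers on one device
)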
@@ -279,6 +279,7 @@ def forward(
 class InternVisionModel(PreTrainedModel):
     main_input_name = 'pixel_values'
     config_class = InternVisionConfig
+    _no_split_modules = ['InternVisionEncoderLayer']
 
     def __init__(self, config: InternVisionConfig):
         super().__init__(config)

@@ -23,6 +23,7 @@
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
     main_input_name = 'pixel_values'
+    _no_split_modules = ['InternVisionEncoderLayer', 'LlamaDecoderLayer', 'LlamaForCausalLM']
 
     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
         super().__init__(config)

@@ -279,6 +279,7 @@ def forward(
 class InternVisionModel(PreTrainedModel):
     main_input_name = 'pixel_values'
     config_class = InternVisionConfig
+    _no_split_modules = ['InternVisionEncoderLayer']
 
     def __init__(self, config: InternVisionConfig):
         super().__init__(config)

@@ -42,7 +42,7 @@ class InternVLPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [
         r'position_ids',
     ]
-    _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
+    _no_split_modules = ['InternVisionEncoderLayer', 'LlamaDecoderLayer', 'LlamaForCausalLM']
     _skip_keys_device_placement = 'past_key_values'
     _keep_in_fp32_modules = ['wo']
 

@@ -24,6 +24,7 @@
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
     main_input_name = 'pixel_values'
+    _no_split_modules = ['InternVisionEncoderLayer', 'LlamaDecoderLayer', 'LlamaForCausalLM']
 
     def __init__(self, config: InternVLChatConfig, internvl=None, language_model=None):
         super().__init__(config)
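
With `_no_split_modules` set, the checkpoints can be loaded with automatic multi-GPU sharding. A minimal usage sketch, not part of the commit; the checkpoint path is a placeholder:

import torch
from transformers import AutoModel

# 'path/to/internvl-chat-checkpoint' is a placeholder; point it at the InternVL checkpoint you use.
model = AutoModel.from_pretrained(
    'path/to/internvl-chat-checkpoint',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # loads the InternVLChatModel class shipped with the checkpoint
    device_map='auto',       # shards whole encoder/decoder layers across available GPUs
).eval()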
