diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 90cfc05d36..760feeaaa4 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -228,33 +228,13 @@ def quantize_model(
         dataset = prepare_dataset(dataset, pad_token_id=pad_token_id, batch_size=batch_size)
 
         # Step 2: get the input of the 1st block
+        # To do that, we need to put the module preceding the first block on the same device as the first block.
+        # Then we run the model and it will stop at the first block as we added a pre-hook that raises an exception after storing the inputs.
+
         layer_inputs = []
         layer_outputs = []
         layer_input_kwargs = []
 
-        class Catcher(nn.Module):
-            """hijack layer's forward pass to cache data"""
-
-            def __init__(self, m):
-                super().__init__()
-                self.module = m
-
-            def forward(self, input=None, **kwargs):
-                # specific to transformers
-                # some models use all key-value arguments in forward pass call
-                if input is None:
-                    if "hidden_states" in kwargs:
-                        input = kwargs["hidden_states"]
-                    else:
-                        raise ValueError("No input value found in the foward pass")
-                layer_inputs.append(input)
-                other_kwargs = {}
-                for k, v in kwargs.items():  # make sure other arguments also be captured
-                    if k not in ["hidden_states"]:
-                        other_kwargs[k] = v
-                layer_input_kwargs.append(other_kwargs)
-                raise ValueError
-
         # get block_name
         if self.block_name_to_quantize is None:
             self.block_name_to_quantize = get_block_name_with_pattern(model)
@@ -275,7 +255,23 @@ def forward(self, input=None, **kwargs):
 
         # get inputs by running self.module_name_preceding_first_block + first block on gpu
         blocks[0] = blocks[0].to(0)
-        blocks[0] = Catcher(blocks[0])
+        def store_input_hook(_, input, *args):
+            kwargs = args[0]
+            input = input[0]
+            if input is None:
+                if "hidden_states" in kwargs:
+                    input = kwargs["hidden_states"]
+                else:
+                    raise ValueError("No input value found in the forward pass")
+            layer_inputs.append(input)
+            other_kwargs = {}
+            for k, v in kwargs.items():  # make sure other arguments are also captured
+                if k not in ["hidden_states"]:
+                    other_kwargs[k] = v
+            layer_input_kwargs.append(other_kwargs)
+            raise ValueError
+
+        handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
         for data in dataset:
             for k, v in data.items():
                 # put the data on gpu, we won't put them back to cpu
@@ -284,7 +280,8 @@ def forward(self, input=None, **kwargs):
                 model(**data)
             except ValueError:
                 pass
-        blocks[0] = blocks[0].module
+
+        handle.remove()
         if not has_device_map:
             blocks[0].to(device)
         for module_name in self.module_name_preceding_first_block:
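For reference, below is a minimal, self-contained sketch of the forward pre-hook pattern this patch switches to, run against a toy model. `TinyModel`, `Block`, and the tensor shapes are invented for illustration and are not part of optimum; only the mechanism (register a pre-hook with `with_kwargs=True`, cache the first block's inputs, raise to abort the forward pass, then remove the handle) mirrors the replacement for the `Catcher` wrapper. Note that `with_kwargs=True` is only available from PyTorch 2.0 onwards.

```python
# Sketch only: TinyModel/Block are hypothetical stand-ins for a transformer and its
# decoder blocks; they are not optimum code.
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states, scale=1.0):
        return self.linear(hidden_states) * scale


class TinyModel(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.embed = nn.Linear(dim, dim)  # stands in for the modules preceding the first block
        self.blocks = nn.ModuleList([Block(dim), Block(dim)])

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x, scale=0.5)
        return x


model = TinyModel()
layer_inputs, layer_input_kwargs = [], []


def store_input_hook(module, args, kwargs):
    # With with_kwargs=True the pre-hook receives (module, positional args, keyword args).
    layer_inputs.append(args[0])
    layer_input_kwargs.append(kwargs)
    # Raising aborts the forward pass once the first block's inputs have been cached.
    raise ValueError


handle = model.blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
try:
    model(torch.randn(2, 8))
except ValueError:
    pass
handle.remove()  # detach the hook so later forward passes run normally

print(layer_inputs[0].shape, layer_input_kwargs[0])  # torch.Size([2, 8]) {'scale': 0.5}
```

Compared with wrapping the block in a `Catcher` module, the hook never replaces `blocks[0]`, so there is no need to unwrap it afterwards; removing the handle is enough to restore normal execution.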