
Commit

replace catcher by pre-forward hook
SunMarc committed Jul 21, 2023
1 parent e404bde commit 89d18d6
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions optimum/gptq/quantizer.py
@@ -228,33 +228,13 @@ def quantize_model(
dataset = prepare_dataset(dataset, pad_token_id=pad_token_id, batch_size=batch_size)

# Step 2: get the input of the 1st block
# To do that, we need to put the modules preceding the first block on the same device as the first block.
# Then we run the model, and it will stop at the first block because we added a pre-hook that raises an Exception after storing the inputs.

layer_inputs = []
layer_outputs = []
layer_input_kwargs = []

class Catcher(nn.Module):
"""hijack layer's forward pass to cache data"""

def __init__(self, m):
super().__init__()
self.module = m

def forward(self, input=None, **kwargs):
# specific to transformers
# some models use all key-value arguments in forward pass call
if input is None:
if "hidden_states" in kwargs:
input = kwargs["hidden_states"]
else:
raise ValueError("No input value found in the foward pass")
layer_inputs.append(input)
other_kwargs = {}
for k, v in kwargs.items():  # make sure other arguments are also captured
if k not in ["hidden_states"]:
other_kwargs[k] = v
layer_input_kwargs.append(other_kwargs)
raise ValueError

# get block_name
if self.block_name_to_quantize is None:
self.block_name_to_quantize = get_block_name_with_pattern(model)
@@ -275,7 +255,23 @@ def forward(self, input=None, **kwargs):
# get inputs by running self.module_name_preceding_first_block + first block on gpu
blocks[0] = blocks[0].to(0)

blocks[0] = Catcher(blocks[0])
def store_input_hook(_, input, *args):
kwargs = args[0]
input = input[0]
if input is None:
if "hidden_states" in kwargs:
input = kwargs["hidden_states"]
else:
raise ValueError("No input value found in the foward pass")
layer_inputs.append(input)
other_kwargs = {}
for k, v in kwargs.items():  # make sure other arguments are also captured
if k not in ["hidden_states"]:
other_kwargs[k] = v
layer_input_kwargs.append(other_kwargs)
raise ValueError

handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
# put the data on the GPU; we won't move it back to the CPU
@@ -284,7 +280,8 @@ def forward(self, input=None, **kwargs):
model(**data)
except ValueError:
pass
blocks[0] = blocks[0].module

handle.remove()
if not has_device_map:
blocks[0].to(device)
for module_name in self.module_name_preceding_first_block:
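For readers unfamiliar with the hook API used above, the sketch below shows the same capture pattern in isolation: a pre-forward hook registered with `with_kwargs=True` receives the block's positional and keyword arguments, caches them, and raises so the rest of the forward pass never runs. The toy `TinyBlock`/`TinyModel` modules are hypothetical stand-ins, not part of optimum; `register_forward_pre_hook(..., with_kwargs=True)` requires PyTorch 2.0 or later.

```python
# Minimal sketch of the capture-via-pre-hook pattern (assumes PyTorch >= 2.0).
# TinyBlock and TinyModel are illustrative stand-ins for a transformer block
# and the model that wraps it; they are not part of optimum.
import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask=None):
        return self.linear(hidden_states)


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 8)  # stands in for the modules preceding the first block
        self.blocks = nn.ModuleList([TinyBlock() for _ in range(3)])

    def forward(self, x, attention_mask=None):
        hidden_states = self.embed(x)
        for block in self.blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        return hidden_states


model = TinyModel()
layer_inputs, layer_input_kwargs = [], []


def store_input_hook(_, args, kwargs):
    # With with_kwargs=True the hook receives (module, positional args, keyword args).
    layer_inputs.append(args[0] if args else kwargs["hidden_states"])
    layer_input_kwargs.append({k: v for k, v in kwargs.items() if k != "hidden_states"})
    # Raising here stops the forward pass: only the layers before the first block run.
    raise ValueError


handle = model.blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
try:
    model(torch.randn(2, 4), attention_mask=torch.ones(2, 8))
except ValueError:
    pass
handle.remove()

print(layer_inputs[0].shape)           # torch.Size([2, 8])
print(layer_input_kwargs[0].keys())    # dict_keys(['attention_mask'])
```

Unlike the removed Catcher wrapper, the hook never replaces blocks[0], so nothing has to be swapped back afterwards; calling handle.remove() is enough to restore normal execution.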
