Skip to content

Commit

Permalink
allow te to use meta device with deferred init (#958)
Browse files Browse the repository at this point in the history
  • Loading branch information
cli99 committed Feb 7, 2024
1 parent 60ab97f commit 8c7d6f4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,7 @@ def __init__(
fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
fc_kwargs['device'] = device
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model + 2 * self.kv_n_heads * self.head_dim,
Expand Down
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __init__(
self.fc_kwargs: dict[str, Any] = {
'bias': bias,
}
if fc_type != 'te':
self.fc_kwargs['device'] = device

self.fc_kwargs['device'] = device

self.up_proj = FC_CLASS_REGISTRY[fc_type](
d_model,
Expand Down

0 comments on commit 8c7d6f4

Please sign in to comment.