From 8c7d6f43f7fdc3ff53f140c8965153eb2db71b85 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Wed, 7 Feb 2024 12:16:32 -0800
Subject: [PATCH] allow te to use meta device with deferred init (#958)

---
 llmfoundry/models/layers/attention.py | 3 +--
 llmfoundry/models/layers/ffn.py       | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index e1120504d7..281f41753a 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -522,8 +522,7 @@ def __init__(
         fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
-        if fc_type != 'te':
-            fc_kwargs['device'] = device
+        fc_kwargs['device'] = device
         self.Wqkv = FC_CLASS_REGISTRY[fc_type](
             self.d_model,
             self.d_model + 2 * self.kv_n_heads * self.head_dim,
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index fa3e109bf8..5e99e0a960 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -97,8 +97,8 @@ def __init__(
         self.fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
-        if fc_type != 'te':
-            self.fc_kwargs['device'] = device
+
+        self.fc_kwargs['device'] = device
         self.up_proj = FC_CLASS_REGISTRY[fc_type](
             d_model,