
Works with the original version, but this version keeps raising IndexError: tuple index out of range #15

Open
sctm002 opened this issue Nov 3, 2022 · 1 comment


@sctm002

sctm002 commented Nov 3, 2022

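```
bsz = inputs[0].size(self.dim)
IndexError: tuple index out of range
```

The original version is written like this:

```python
model = DataParallel(model, device_ids=[int(i) for i in args.device.split(',')])
```

Following this version's instructions, I wrote:

```python
model = BalancedDataParallel(1, model, dim=0).cuda()
```

and it keeps raising that error. The documentation for this is far too sparse, and I don't know where to start debugging.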

@sherlcok314159

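This means that when you feed the inputs to the model, you are probably passing them as a dictionary, which is common with transformers, for example:

```python
inputs = {
    "input_ids": ...,
    "attention_mask": ...,
    "token_type_ids": ...,
}
outputs = model(**inputs)
```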
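To make this concrete, here is a minimal pure-Python sketch (no PyTorch needed). `ToyParallel` and the placeholder string values are hypothetical stand-ins; the only thing carried over is a `forward(self, *inputs, **kwargs)`-style signature that reads the first positional input, the way `bsz = inputs[0].size(self.dim)` does.

```python
class ToyParallel:
    """Hypothetical stand-in for a DataParallel-style wrapper."""

    def forward(self, *inputs, **kwargs):
        # Return what arrived positionally vs. by keyword.
        return inputs, kwargs

    def first_input(self, *inputs, **kwargs):
        # Mirrors `bsz = inputs[0].size(self.dim)`: index the first positional input.
        return inputs[0]


wrapper = ToyParallel()
# Placeholder values instead of real tensors.
batch = {"input_ids": "ids", "attention_mask": "mask", "token_type_ids": "types"}

positional, keyword = wrapper.forward(**batch)
print(positional)       # ()
print(sorted(keyword))  # ['attention_mask', 'input_ids', 'token_type_ids']

try:
    wrapper.first_input(**batch)
except IndexError as err:
    print(err)          # tuple index out of range
```

Passing the tensors positionally, e.g. `model(input_ids, attention_mask)`, would also sidestep the error, provided the model's `forward` accepts them in that order.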

Now look at how the source code handles this:

```python
def scatter(self, inputs, kwargs, device_ids):
    # Get bsz (the batch size) from the first positional input
    bsz = inputs[0].size(self.dim)
    num_dev = len(self.device_ids)
```
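`bsz` is needed first because the wrapper has to decide how many samples each GPU receives. As a hedged illustration only, the sketch below assumes the first constructor argument (the `1` in `BalancedDataParallel(1,model, dim=0)`) is the batch size given to GPU 0, with the rest spread evenly over the other GPUs. `balanced_chunk_sizes` is a hypothetical helper, not this project's actual code.

```python
from typing import List


def balanced_chunk_sizes(bsz: int, num_dev: int, gpu0_bsz: int) -> List[int]:
    """Hypothetical helper: split `bsz` samples over `num_dev` GPUs,
    giving GPU 0 only `gpu0_bsz` samples (assumes num_dev >= 2)."""
    if num_dev < 2:
        raise ValueError("need at least two devices to rebalance")
    rest = bsz - gpu0_bsz
    if rest < 0:
        raise ValueError("gpu0_bsz cannot exceed the batch size")
    # Spread the remaining samples as evenly as possible over GPUs 1..num_dev-1;
    # the first `extra` of them get one additional sample.
    base, extra = divmod(rest, num_dev - 1)
    return [gpu0_bsz] + [base + (1 if i < extra else 0) for i in range(num_dev - 1)]


print(balanced_chunk_sizes(bsz=32, num_dev=4, gpu0_bsz=2))  # [2, 10, 10, 10]
print(balanced_chunk_sizes(bsz=11, num_dev=4, gpu0_bsz=1))  # [1, 4, 3, 3]
```

None of this can run until `bsz` is known, which is why a missing positional input fails immediately.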

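So when you call the model as `model(**inputs)`, `inputs` inside `scatter` is an empty tuple (all the tensors arrive in `kwargs`), and `inputs[0]` cannot work. You can change the part of `scatter` that gets `bsz` to this:

```python
def scatter(self, inputs, kwargs, device_ids):
    if len(inputs) > 0:
        # Positional call, e.g. model(input_ids, attention_mask)
        bsz = inputs[0].size(self.dim)
    elif kwargs:
        # Keyword-only call, e.g. model(**inputs); assumes the first value is a tensor
        bsz = list(kwargs.values())[0].size(self.dim)
    else:
        raise ValueError("You must pass inputs to the model!")
    num_dev = len(self.device_ids)
    ...
```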
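The keyword fallback above uses the first keyword value, which assumes that value is a tensor. A slightly more defensive variant, sketched as a standalone function so it can be tested without GPUs, checks positional inputs and then keyword values and uses the first object with a callable `size` method. `infer_batch_size` and `FakeTensor` are hypothetical names, and the `return_dict=True` entry is only an illustrative non-tensor argument; in real use the values would be `torch.Tensor` objects, whose `size(dim)` returns the length of that dimension.

```python
from typing import Any, Dict, Tuple


def infer_batch_size(inputs: Tuple[Any, ...], kwargs: Dict[str, Any], dim: int = 0) -> int:
    """Hypothetical helper: batch size of the first tensor-like argument."""
    # Check positional arguments first, then keyword arguments (model(**batch)).
    for value in list(inputs) + list(kwargs.values()):
        size = getattr(value, "size", None)
        if callable(size):  # skips non-tensor flags such as return_dict=True
            return size(dim)
    raise ValueError("You must pass inputs to the model!")


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only implements size(dim)."""

    def __init__(self, *shape: int):
        self.shape = shape

    def size(self, dim: int) -> int:
        return self.shape[dim]


batch = {"return_dict": True, "input_ids": FakeTensor(8, 128)}
print(infer_batch_size((), batch))                 # 8
print(infer_batch_size((FakeTensor(4, 16),), {}))  # 4
```

Inside `scatter`, the `if/elif/else` could then be replaced by `bsz = infer_batch_size(inputs, kwargs, self.dim)`.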

sherlcok314159 mentioned this issue Nov 13, 2022