diff --git a/adv_control/control.py b/adv_control/control.py index 27e3d84..112bbe1 100644 --- a/adv_control/control.py +++ b/adv_control/control.py @@ -545,7 +545,7 @@ def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, mo has_temporal_res_block_key = True # ControlNet++ check elif "task_embedding" in key: - raise Exception("ControlNet++ model detected; must be loaded using the Load ControlNet++ Model nodes.") + pass if has_controlnet_key and has_motion_modules_key: controlnet_type = ControlWeightType.SPARSECTRL diff --git a/adv_control/control_plusplus.py b/adv_control/control_plusplus.py index 3e2d8c4..b4da295 100644 --- a/adv_control/control_plusplus.py +++ b/adv_control/control_plusplus.py @@ -164,7 +164,7 @@ def union_controlnet_merge(self, hint: list[Tensor], control_type, emb, context) controlnet_cond = self.input_hint_block(hint[indexes[idx][0]], emb, context) feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) if idx < indexes.shape[0]: - feat_seq += self.task_embedding[indexes[idx][0]] + feat_seq += self.task_embedding[indexes[idx][0]].to(dtype=feat_seq.dtype, device=feat_seq.device) inputs.append(feat_seq.unsqueeze(1)) condition_list.append(controlnet_cond)