-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/master' into group-nodes
- Loading branch information
Showing
37 changed files
with
979 additions
and
2,716 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import enum | ||
import torch | ||
import math | ||
import comfy.utils | ||
|
||
|
||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) | ||
return abs(a*b) // math.gcd(a, b) | ||
|
||
class CONDRegular: | ||
def __init__(self, cond): | ||
self.cond = cond | ||
|
||
def _copy_with(self, cond): | ||
return self.__class__(cond) | ||
|
||
def process_cond(self, batch_size, device, **kwargs): | ||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) | ||
|
||
def can_concat(self, other): | ||
if self.cond.shape != other.cond.shape: | ||
return False | ||
return True | ||
|
||
def concat(self, others): | ||
conds = [self.cond] | ||
for x in others: | ||
conds.append(x.cond) | ||
return torch.cat(conds) | ||
|
||
class CONDNoiseShape(CONDRegular): | ||
def process_cond(self, batch_size, device, area, **kwargs): | ||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] | ||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) | ||
|
||
|
||
class CONDCrossAttn(CONDRegular): | ||
def can_concat(self, other): | ||
s1 = self.cond.shape | ||
s2 = other.cond.shape | ||
if s1 != s2: | ||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen | ||
return False | ||
|
||
mult_min = lcm(s1[1], s2[1]) | ||
diff = mult_min // min(s1[1], s2[1]) | ||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much | ||
return False | ||
return True | ||
|
||
def concat(self, others): | ||
conds = [self.cond] | ||
crossattn_max_len = self.cond.shape[1] | ||
for x in others: | ||
c = x.cond | ||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) | ||
conds.append(c) | ||
|
||
out = [] | ||
for c in conds: | ||
if c.shape[1] < crossattn_max_len: | ||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result | ||
out.append(c) | ||
return torch.cat(out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.