Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstring to token_weights() and related functions in sd1_clip #4640

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
148 changes: 148 additions & 0 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,27 @@ def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)

def parse_parentheses(string):
"""
Split a string based off top-level nested parentheses.

Parameters
----------
string : str
The string to be split into its top nested groups.

Returns
-------
result : list
A list of each element in string, split by top-level elements

Examples
--------
>>> string = "(foo)(bar)"
['(foo)', '(bar)']

>>> string = "(foo(bar)(test1))(test2(test3))"
['(foo(bar)(test1))', '(test2(test3))']
"""
result = []
current_item = ""
nesting_level = 0
Expand Down Expand Up @@ -260,13 +281,83 @@ def parse_parentheses(string):
return result

def token_weights(string, current_weight):
"""
Find the requested weight of a token, and multiply it by the current weight. For parentheses groupings with no set weight, multiply by 1.1.

Parameters
----------
string : str
The input of tokens to calculate requested weights for
current_weight : float
The current weight of all tokens

Returns
-------
out : list
A list of each token paired with the calculated weight

Examples
--------
>>> string = "(foo)"
>>> current_weight = 1.0
[('foo', 1.1)]

>>> string = "(foo)(bar)"
>>> current_weight = 1.0
[('foo', 1.1), ('bar', 1.1)]

>>> string = "(foo:2.0)"
>>> current_weight = 1.0
[('foo', 2.0)]

>>> string = "((foo))"
>>> current_weight = 1.0
[('foo', 1.21)]

>>> string = "((foo):1.1)"
>>> current_weight = 1.0
[('foo', 1.21)]

>>> string = "((foo:1.1))"
>>> current_weight = 1.0
[('foo', 1.1)]

>>> string = "(foo:0.0)"
>>> current_weight = 1.0
[('foo', 0.0)]

>>> string = "((foo:1.0):0.0)"
>>> current_weight = 1.0
[('foo', 1.0)]

>>> string = "foo ((((lol (cat:666) attack:100)))) baz"
>>> current_weight = 1.0
[('foo ', 1.0), ('lol ', 100.0), ('cat', 666.0), (' attack', 100.0), (' baz', 1.0)]

>>> string = "foo ((((lol (cat:666) attack):100))) baz"
>>> current_weight = 1.0
[('foo ', 1.0), ('lol ', 110.0), ('cat', 666.0), (' attack', 110.0), (' baz', 1.0)]

Notes
-----
See issue #4610 for more detail. One thing to note is that the default of 1.1 is multiplied
when there is no weight defined on the *interior* of the group instead of the exterior. This
behavior can be seen in the last two examples (thank you @jart for making these). In the first
example, the weight of 100 is defined inside the same parentheses grouping as both 'lol' and
'attack'. There is no parentheses between the defined weight and the tokens, so there is no
multiplication of the weight by 1.1. In the second example, the weight is outside the parentheses
grouping, so the weights inside the grouping are first given a modifier of 1.1, then given a
modifier of 100.

"""
a = parse_parentheses(string)
out = []
for x in a:
weight = current_weight
if len(x) >= 2 and x[-1] == ')' and x[0] == '(':
x = x[1:-1]
xx = x.rfind(":")
# This line makes *all nestings* multiply the weight by 1.1
weight *= 1.1
if xx > 0:
try:
Expand All @@ -280,11 +371,55 @@ def token_weights(string, current_weight):
return out

def escape_important(text):
"""
Replace parentheses marked via backslashes with escape characters

Parameters
----------
text : str
The string to have its important parentheses replaced

Returns
-------
text : str
The input string, with important parentheses replaced

Examples
--------
>>> text = "\\(foo\\)(bar)"
"\0\2foo\0\1(bar)"

See Also
--------
unescape_important : Replace escape characters with parentheses
"""
text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2")
return text

def unescape_important(text):
"""
Replace escape characters made via escape_important with parentheses

Parameters
----------
text : str
The string to have its escape characters replaced

Returns
-------
text : str
The input string, with escape characters replaced

Examples
--------
>>> text = "\0\2foo\0\1(bar)"
"(foo)(bar)"

See Also
--------
escape_important: makes strings with the escape characters this function uses
"""
text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(")
return text
Expand All @@ -309,6 +444,19 @@ def safe_load_embed_zip(embed_path):
return out

def expand_directory_list(directories):
"""
For all directories in a list, list all subdirectories beneath them

Parameters
----------
directories : list
The list of directories to search for subdirectories with

Returns
-------
dirs : list
A list of all subdirectories found underneath the given directories
"""
dirs = set()
for x in directories:
dirs.add(x)
Expand Down
Loading