Skip to content

Commit

Permalink
Formatting and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
mwootendev committed Dec 20, 2023
1 parent 89b8154 commit 27e0c5c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/dynamicprompts/sampling_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def process_variable_assignment(
self,
command: VariableAssignmentCommand,
) -> Command:
if command.preserve and command.name in self.variables:
if command.preserve and command.name in self.variables:
return self.variables[command.name]
if command.immediate:
if isinstance(command.value, LiteralCommand):
Expand Down
4 changes: 2 additions & 2 deletions tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,12 @@ def test_alternative_wildcard_wrap(self, wildcard_wrap: str, template: str):

@pytest.mark.parametrize(
("immediate", "preserve"),
[
[
(False, False),
(False, True),
(True, False),
(True, True),
]
],
)
def test_variable_commands(self, immediate: bool, preserve: bool):
op = "?" if preserve else ""
Expand Down
59 changes: 33 additions & 26 deletions tests/samplers/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,34 +759,41 @@ def test_immediate_literal_variable(self, random_sampling_context: SamplingConte
assert str(next(random_sampling_context.generator_from_command(cmd))) == "foo"

@pytest.mark.parametrize(
"prompt, possible_results",
[
(
"${season=summer} ${temp=cold} ${location=north}__drink/beverage__",
("a glass of iced tea", "a glass of iced pop")
),
(
"${season=summer} ${temp=cold} ${location=south}__drink/winter/beverage__",
("a mug of hot coffee")
),
(
"${season=summer} ${temp=cold}__drink/winter/beverage__",
("a mug of hot tea")
),
(
"__drink/summer/beverage__",
("a glass of iced sweet tea", "a glass of iced soda")
),
(
"${location=north}__drink/summer/beverage__",
("a glass of iced tea", "a glass of iced pop")
)
]
"prompt, possible_results",
[
(
"${season=summer} ${temp=cold} ${location=north}__drink/beverage__",
("a glass of iced tea", "a glass of iced pop"),
),
(
"${season=summer} ${temp=cold} ${location=south}__drink/winter/beverage__",
("a mug of hot coffee"),
),
(
"${season=summer} ${temp=cold}__drink/winter/beverage__",
("a mug of hot tea"),
),
(
"__drink/summer/beverage__",
("a glass of iced sweet tea", "a glass of iced soda"),
),
(
"${location=north}__drink/summer/beverage__",
("a glass of iced tea", "a glass of iced pop"),
),
],
)
def test_preserve_variable(self, random_sampling_context: SamplingContext, prompt: str, possible_results: list[str]):
def test_preserve_variable(
self,
random_sampling_context: SamplingContext,
prompt: str,
possible_results: list[str],
):
cmd = parse(prompt)
resolved_value = str(next(random_sampling_context.generator_from_command(cmd))).strip()
assert resolved_value in possible_results
resolved_value = str(
next(random_sampling_context.generator_from_command(cmd)),
).strip()
assert resolved_value in possible_results

def test_unknown_variable(self, wildcard_manager: WildcardManager):
ctx1 = SamplingContext(
Expand Down
11 changes: 5 additions & 6 deletions tests/test_data/wildcards/drink.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ drink:
beverage:
- '${temp=hot} ${season=winter} ${location?=north} __drink/beverage__'

north:
north:
- tea
south:
- coffee
- coffee

summer:
beverage:
Expand All @@ -18,17 +18,16 @@ drink:
- pop
south:
- sweet tea
- soda
- soda

container:
container:
hot:
- mug
cold:
- glass

temp:
hot:
hot:
- hot
cold:
- iced

0 comments on commit 27e0c5c

Please sign in to comment.