diff --git a/src/dynamicprompts/sampling_context.py b/src/dynamicprompts/sampling_context.py index 6304a0e..a062464 100644 --- a/src/dynamicprompts/sampling_context.py +++ b/src/dynamicprompts/sampling_context.py @@ -162,7 +162,7 @@ def process_variable_assignment( self, command: VariableAssignmentCommand, ) -> Command: - if command.preserve and command.name in self.variables: + if command.preserve and command.name in self.variables: return self.variables[command.name] if command.immediate: if isinstance(command.value, LiteralCommand): diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py index e4d6f98..ff684f9 100644 --- a/tests/parser/test_parser.py +++ b/tests/parser/test_parser.py @@ -392,12 +392,12 @@ def test_alternative_wildcard_wrap(self, wildcard_wrap: str, template: str): @pytest.mark.parametrize( ("immediate", "preserve"), - [ + [ (False, False), (False, True), (True, False), (True, True), - ] + ], ) def test_variable_commands(self, immediate: bool, preserve: bool): op = "?" if preserve else "" diff --git a/tests/samplers/test_common.py b/tests/samplers/test_common.py index 928b42b..cfa215b 100644 --- a/tests/samplers/test_common.py +++ b/tests/samplers/test_common.py @@ -759,34 +759,41 @@ def test_immediate_literal_variable(self, random_sampling_context: SamplingConte assert str(next(random_sampling_context.generator_from_command(cmd))) == "foo" @pytest.mark.parametrize( - "prompt, possible_results", - [ - ( - "${season=summer} ${temp=cold} ${location=north}__drink/beverage__", - ("a glass of iced tea", "a glass of iced pop") - ), - ( - "${season=summer} ${temp=cold} ${location=south}__drink/winter/beverage__", - ("a mug of hot coffee") - ), - ( - "${season=summer} ${temp=cold}__drink/winter/beverage__", - ("a mug of hot tea") - ), - ( - "__drink/summer/beverage__", - ("a glass of iced sweet tea", "a glass of iced soda") - ), - ( - "${location=north}__drink/summer/beverage__", - ("a glass of iced tea", "a glass of iced pop") - ) - ] + "prompt, possible_results", + [ + ( + "${season=summer} ${temp=cold} ${location=north}__drink/beverage__", + ("a glass of iced tea", "a glass of iced pop"), + ), + ( + "${season=summer} ${temp=cold} ${location=south}__drink/winter/beverage__", + ("a mug of hot coffee"), + ), + ( + "${season=summer} ${temp=cold}__drink/winter/beverage__", + ("a mug of hot tea"), + ), + ( + "__drink/summer/beverage__", + ("a glass of iced sweet tea", "a glass of iced soda"), + ), + ( + "${location=north}__drink/summer/beverage__", + ("a glass of iced tea", "a glass of iced pop"), + ), + ], ) - def test_preserve_variable(self, random_sampling_context: SamplingContext, prompt: str, possible_results: list[str]): + def test_preserve_variable( + self, + random_sampling_context: SamplingContext, + prompt: str, + possible_results: list[str], + ): cmd = parse(prompt) - resolved_value = str(next(random_sampling_context.generator_from_command(cmd))).strip() - assert resolved_value in possible_results + resolved_value = str( + next(random_sampling_context.generator_from_command(cmd)), + ).strip() + assert resolved_value in possible_results def test_unknown_variable(self, wildcard_manager: WildcardManager): ctx1 = SamplingContext( diff --git a/tests/test_data/wildcards/drink.yaml b/tests/test_data/wildcards/drink.yaml index 9c26e68..ef0d32b 100644 --- a/tests/test_data/wildcards/drink.yaml +++ b/tests/test_data/wildcards/drink.yaml @@ -5,10 +5,10 @@ drink: beverage: - '${temp=hot} ${season=winter} ${location?=north} __drink/beverage__' - north: + north: - tea south: - - coffee + - coffee summer: beverage: @@ -18,17 +18,16 @@ drink: - pop south: - sweet tea - - soda + - soda - container: + container: hot: - mug cold: - glass temp: - hot: + hot: - hot cold: - iced - \ No newline at end of file