From 6e770c0ddabcab80fd76306efdde671bdb099ed6 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Tue, 10 Oct 2023 20:40:28 +0000 Subject: [PATCH] [dynamo] Add `itertools.repeat` via polyfill (#110953) Fixes https://github.com/pytorch/pytorch/issues/110286 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110953 Approved by: https://github.com/ezyang --- test/dynamo/test_misc.py | 18 ++++++++++++++++++ torch/_dynamo/polyfill.py | 5 +++++ torch/_dynamo/variables/misc.py | 13 ++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 3c114c7dcb147..ce4539531791c 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -7291,6 +7291,24 @@ def fn(target): res = opt_func(a) self.assertIsInstance(res, torch.Tensor) + def test_itertools_repeat(self): + counters.clear() + + def fn(x): + r = itertools.repeat(100.0, 5) + for i in r: + x += i + return x + + x = torch.randn([2, 5]) + eager = fn(x) + + compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) + compiled = compiled_fn(x) + + self.assertEqual(list(eager), list(compiled)) + self.assertEqual(len(counters["graph_break"]), 0) + def test_itertools_accumulate_symint_default_sum(self): # https://github.com/pytorch/pytorch/issues/110287 counters.clear() diff --git a/torch/_dynamo/polyfill.py b/torch/_dynamo/polyfill.py index c5844b6a4cd0b..25c96f9f31dfe 100644 --- a/torch/_dynamo/polyfill.py +++ b/torch/_dynamo/polyfill.py @@ -16,3 +16,8 @@ def index(iterator, item, start=0, end=-1): return i # This will not run in dynamo raise ValueError(f"{item} is not in {type(iterator)}") + + +def repeat(item, count): + for i in range(count): + yield item diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 4a4d64c3b6c45..32ca676a75317 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -9,7 +9,7 @@ import torch._C import torch._numpy as tnp -from .. import config, variables +from .. import config, polyfill, variables from ..bytecode_transformation import create_call_function, create_instruction from ..exc import unimplemented from ..guards import GuardBuilder @@ -914,6 +914,17 @@ def wraps(fn): return variables.functions.FunctoolsPartialVariable( fn, args=rest_args, keywords=kwargs, **options ) + elif self.value is itertools.repeat: + from .builder import SourcelessBuilder + + if len(args) < 2: + # We cannot risk infinite generator being consumed to exhaustion by dynamo + # (i.e. infinite loop) + unimplemented("Infinite repeat is not supported") + + return tx.inline_user_function_return( + SourcelessBuilder()(tx, polyfill.repeat), args, kwargs + ) else: try: path = inspect.getfile(self.value)