Skip to content

Commit

Permalink
[dynamo] Add itertools.repeat via polyfill (pytorch#110953)
Browse files Browse the repository at this point in the history
Fixes pytorch#110286

Pull Request resolved: pytorch#110953
Approved by: https://github.com/ezyang
  • Loading branch information
jon-chuang authored and pytorchmergebot committed Oct 10, 2023
1 parent 02a02a2 commit 6e770c0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
18 changes: 18 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7291,6 +7291,24 @@ def fn(target):
res = opt_func(a)
self.assertIsInstance(res, torch.Tensor)

def test_itertools_repeat(self):
counters.clear()

def fn(x):
r = itertools.repeat(100.0, 5)
for i in r:
x += i
return x

x = torch.randn([2, 5])
eager = fn(x)

compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
compiled = compiled_fn(x)

self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)

def test_itertools_accumulate_symint_default_sum(self):
# https://github.com/pytorch/pytorch/issues/110287
counters.clear()
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/polyfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,8 @@ def index(iterator, item, start=0, end=-1):
return i
# This will not run in dynamo
raise ValueError(f"{item} is not in {type(iterator)}")


def repeat(item, count):
for i in range(count):
yield item
13 changes: 12 additions & 1 deletion torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch._C
import torch._numpy as tnp
from .. import config, variables
from .. import config, polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder
Expand Down Expand Up @@ -914,6 +914,17 @@ def wraps(fn):
return variables.functions.FunctoolsPartialVariable(
fn, args=rest_args, keywords=kwargs, **options
)
elif self.value is itertools.repeat:
from .builder import SourcelessBuilder

if len(args) < 2:
# We cannot risk infinite generator being consumed to exhaustion by dynamo
# (i.e. infinite loop)
unimplemented("Infinite repeat is not supported")

return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
)
else:
try:
path = inspect.getfile(self.value)
Expand Down

0 comments on commit 6e770c0

Please sign in to comment.