diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 03d18f7c3f3f9..2a9a966628ad8 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -12,6 +12,7 @@ import torch import torch._export import torch._inductor +import torch._inductor.config import torch.nn as nn from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters @@ -1313,14 +1314,19 @@ def fn(a, b, alpha=1.0): with self.assertRaises(RuntimeError): torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0}) - so_path = torch._export.aot_compile( - torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False - ) - kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) - res = kernel_runner.run([a, b]) - self.assertTrue(isinstance(res, list)) - self.assertTrue(len(res) == 1) - self.assertEqual(fn(a, b, alpha=2.0), res[0]) + for simdlen in [0, None]: + with torch._inductor.config.patch({"cpp.simdlen": simdlen}): + so_path = torch._export.aot_compile( + torch.ops.aten.add, + args=(a, b), + kwargs={"alpha": 2.0}, + same_signature=False, + ) + kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path) + res = kernel_runner.run([a, b]) + self.assertTrue(isinstance(res, list)) + self.assertTrue(len(res) == 1) + self.assertEqual(fn(a, b, alpha=2.0), res[0]) def test_buffer_mutation_2(self): class Model(torch.nn.Module): diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index b5aa1d1b8a61b..617a7ba7e2626 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1717,6 +1717,11 @@ def get_include_and_linking_paths( else: libs = ["omp"] if config.is_fbcode() else ["gomp"] + # For AOT mode, the produced library relies on torch cpu to set grad mode + # like aoti_torch_grad_mode_set_enabled + if aot_mode and sys.platform == "linux" and not config.is_fbcode(): + libs += ["torch", "torch_cpu"] + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 if not config.abi_compatible: libs += ["c10"]