[iree-turbine] support simple mlp training for cuda by dynamo & ignore torch.none when it appears in the backward graph
jun.pang committed May 10, 2024
1 parent 31d4378 commit c3a485e
Showing 3 changed files with 145 additions and 21 deletions.
60 changes: 60 additions & 0 deletions examples/mlp_train/ut_mlp_train.py
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

device = 'cuda'

# Fit a linear model y = w * x + b to synthetic data generated as y = 3x + 2 + noise
torch.cuda.manual_seed_all(0)
x = torch.linspace(-1, 1, 100).reshape(-1)
y = 3 * x + 2 + torch.randn(x.size()) * 0.2

# move the data to the target device as float32 (x and y are already tensors)
x = x.to(dtype=torch.float32, device=device)
y = y.to(dtype=torch.float32, device=device)
print(x)
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.weight = nn.Parameter(torch.randn(1, requires_grad=True))
print(self.weight)
self.bias = nn.Parameter(torch.randn(1, requires_grad=True))

def forward(self, x : torch.Tensor):
out = x * self.weight + self.bias
return out


# model = SimpleMLP().to(device)
mod = SimpleMLP().to(device)

model = torch.compile(mod, backend='turbine_cpu')

learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
loss_func = nn.MSELoss()

epochs = 2000
for epoch in range(epochs):
y_pred = model(x)
# print(y_pred)

loss = loss_func(y_pred.to(device), y.to(device))

optimizer.zero_grad()
# loss = y_pred.sum()
# loss = loss.to(device)
loss.backward()

optimizer.step()

if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

predicted = model(x).detach().cpu().numpy()
plt.plot(x.cpu().numpy(), y.cpu().numpy(), 'ro', label='Original data')
plt.plot(x.cpu().numpy(), predicted, label='Fitted line')
plt.legend()
plt.savefig('fitting_result.png')
plt.close()
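
The example above drives both the forward and backward pass through torch.compile with the turbine_cpu backend; after this commit that backend selects the IREE llvm-cpu or cuda target from the device of the example inputs. As a rough usage sketch (not part of the commit), one might run the script with the IR-dump environment variables read by shark_turbine/dynamo/backends/cpu.py below; the driver code and path handling here are illustrative only:

import os
import runpy

# These variables are checked at import time in shark_turbine/dynamo/backends/cpu.py,
# so they must be set before anything imports that module.
os.environ["mlir_print_ir_after_all"] = "1"    # adds --mlir-print-ir-after-all / --mlir-print-ir-after-change
# os.environ["mlir_print_ir_before_all"] = "1" # adds --mlir-print-ir-before-all

# Run the training example added by this commit.
runpy.run_path("examples/mlp_train/ut_mlp_train.py", run_name="__main__")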
59 changes: 51 additions & 8 deletions shark_turbine/dynamo/backends/cpu.py
@@ -6,6 +6,7 @@

import functools
import sys
import os

from ...runtime.device import (
DeviceState,
@@ -16,6 +17,7 @@
)

from iree.compiler.api import (
_initializeGlobalCL,
Invocation,
Session,
Source,
@@ -38,11 +40,31 @@
import torch
from torch._dynamo.backends.common import aot_autograd
from ..passes import turbine_cpu_pass_pipeline
from typing import Any, List
from functorch.compile import min_cut_rematerialization_partition

DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
DEFAULT_COMPILER_FLAGS = (
"--iree-input-type=torch",
)

global_cl_options = []
if os.getenv("mlir_print_ir_after_all") is not None:
global_cl_options.append("--mlir-print-ir-after-all")
global_cl_options.append("--mlir-print-ir-after-change")

if os.getenv("mlir_print_ir_before_all") is not None:
global_cl_options.append("--mlir-print-ir-before-all")


if len(global_cl_options) != 0:
_initializeGlobalCL("dynamo", *global_cl_options)

def device_from_inputs(example_inputs) -> torch.device:
for x in example_inputs:
if hasattr(x, "device"):
return x.device

def _base_backend(gm: torch.fx.GraphModule, example_inputs):
def _base_backend(gm: torch.fx.GraphModule, example_inputs, is_fw=True):
# Set up the session, context and invocation.
# Note that we do this on one in-memory module in a few phases:
# 1. Build it from the FX graph.
@@ -52,7 +74,18 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
# 4. Output to an mmap buffer.
session = Session()
session.set_flags(*DEFAULT_COMPILER_FLAGS)
session.set_flags("--iree-hal-target-backends=llvm-cpu")

device = device_from_inputs(example_inputs)


device_index = None
device_type = device.type
if device_type == "cpu":
session.set_flags("--iree-hal-target-backends=llvm-cpu")
elif device_type == "cuda":
device_index = device.index
session.set_flags("--iree-hal-target-backends=cuda")

context = session.context
importer = FxImporter(context=context)
module = importer.module
@@ -65,6 +98,8 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
gm = turbine_cpu_pass_pipeline(gm, example_inputs)

# Import phase.
print("before import graph")
print(gm.print_readable(), file=sys.stderr)
importer.import_graph_module(gm)
print(module, file=sys.stderr)
with context:
@@ -80,7 +115,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
inv.output_vm_bytecode(output)

# Set up for runtime.
device_state = _get_device_state()
device_state = _get_device_state(device_type, device_index)
# TODO: Switch to wrap_buffer once https://github.com/openxla/iree/issues/14926
# is fixed.
# vmfb_module = VmModule.wrap_buffer(
@@ -94,14 +129,22 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
)
output.close()

return SpecializedExecutable(vmfb_module, device_state)
return SpecializedExecutable(vmfb_module, device_state, importer.anticipated_return_value)

def _base_backend_fw(gm: torch.fx.GraphModule, example_inputs):
return _base_backend(gm, example_inputs, is_fw=True)

backend = aot_autograd(fw_compiler=_base_backend)
def _base_backend_bw(gm: torch.fx.GraphModule, example_inputs):
return _base_backend(gm, example_inputs, is_fw=False)

backend = aot_autograd(fw_compiler=_base_backend_fw, bw_compiler=_base_backend_bw, partition_fn=functools.partial(min_cut_rematerialization_partition, compiler="turbine_cpu"))

# IREE runtime globals. For the CPU right now, there is no device selection,
# so it is easy.
@functools.lru_cache(maxsize=None)
def _get_device_state() -> DeviceState:
return DeviceState(driver="local-task")
def _get_device_state(device_type, device_index) -> DeviceState:
if device_type == "cpu":
return DeviceState(driver="local-task")
elif device_type == "cuda":
return DeviceState(driver="cuda", enumerated_info={'device_id':device_index})

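For reference, the new wiring hands the forward and backward graphs to separate compilers and lets min_cut_rematerialization_partition decide the split. A minimal stand-alone sketch of that contract, with throwaway printing compilers in place of the IREE-backed ones (everything named here is hypothetical except the PyTorch APIs themselves, which the diff above also uses):

import functools
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import min_cut_rematerialization_partition, make_boxed_func

def _print_fw(gm: torch.fx.GraphModule, example_inputs):
    # Receives the partitioned forward graph.
    gm.print_readable()
    return make_boxed_func(gm.forward)

def _print_bw(gm: torch.fx.GraphModule, example_inputs):
    # Receives the partitioned backward graph; its outputs may include None
    # placeholders for inputs that get no gradient.
    gm.print_readable()
    return make_boxed_func(gm.forward)

debug_backend = aot_autograd(
    fw_compiler=_print_fw,
    bw_compiler=_print_bw,
    partition_fn=functools.partial(min_cut_rematerialization_partition, compiler="debug"),
)

# Usage: compiled = torch.compile(model, backend=debug_backend)
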
47 changes: 34 additions & 13 deletions shark_turbine/dynamo/executor.py
@@ -8,6 +8,7 @@
import os
from typing import List, Optional, Sequence, Union
from dataclasses import dataclass
import torch.nn as nn
from iree.runtime import (
asdevicearray,
create_hal_module,
@@ -31,7 +32,7 @@
)

from ..runtime.device import Device, DeviceState

from ..dynamo.tensor import dtype_to_element_type

@functools.lru_cache(maxsize=None)
def get_vm_instance() -> VmInstance:
@@ -64,12 +65,14 @@ class SpecializedExecutable:
"entry_function",
"user_module",
"vm_context",
"anticipated_return_value",
]

def __init__(
self,
user_module: VmModule,
device_state: DeviceState,
anticipated_return_value: list[bool],
entry_name: str = "main",
):
self.user_module = user_module
@@ -81,6 +84,7 @@ def __init__(
),
)
self.device_state = device_state
self.anticipated_return_value = anticipated_return_value
self.entry_function = self.user_module.lookup_function(entry_name)

def __call__(self, *inputs):
Expand All @@ -101,26 +105,43 @@ def _inputs_to_device(self, inputs: list, arg_list: VmVariantList):
# TODO: We are assuming the worst case here which is that we have unknown Torch
# tensors that we send to the CPU and make contiguous. Ideally, we would have
# fast paths for our own backends and interop.
device = self.device_state.device
device_name = self.device_state.torch_device
for input in inputs:
input_cpu = input.cpu().contiguous()
# Since this is already a fallback case, just use the numpy array interop.
# It isn't great, but meh... fallback case.
device_array = asdevicearray(self.device_state.device, input_cpu)
arg_list.push_ref(device_array._buffer_view)

# input_cpu = input.cpu().contiguous()
# # Since this is already a fallback case, just use the numpy array interop.
# # It isn't great, but meh... fallback case.
# device_array = asdevicearray(self.device_state.device, input_cpu)
# arg_list.push_ref(device_array._buffer_view)
if not input.is_contiguous():
input = input.cpu().contiguous()

if input.device.type.startswith("cpu"):
if device_name.startswith("cuda"):
input = input.to("cuda")

if(isinstance(input, nn.Parameter)):
buffer_view = device.allocator.import_buffer(device, input.data, dtype_to_element_type(input.dtype))
else:
buffer_view = device.allocator.import_buffer(device, input, dtype_to_element_type(input.dtype))
arg_list.push_ref(buffer_view)

def _returns_to_user(self, ret_list: VmVariantList):
# TODO: This is also not good that we are moving back to the CPU like this.
# We should be returning a custom Tensor implementation which represents
# our device data and has synchronization hooks for accessing it.
device = self.device_state.device
num_returns = len(ret_list)
# num_returns = len(ret_list)
num_returns = len(self.anticipated_return_value)
user_returns = [None] * num_returns
for i in range(num_returns):
device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i))
device_array = DeviceArray(device, device_buffer_view)
host_array = device_array.to_host()
user_returns[i] = torch_from_numpy(host_array) # type: ignore
ret_list_idx = 0  # self.anticipated_return_value may contain None entries, so track the ret_list index separately

for i in range(num_returns):
if self.anticipated_return_value[i]:
device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(ret_list_idx))
ret_list_idx += 1
element_type = HalElementType(device_buffer_view.element_type)
user_returns[i] = device.allocator.export_buffer(device, device_buffer_view, element_type)
return user_returns
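
To make the "ignore torch.none" part concrete: the backward graph can declare outputs that are None (inputs with no gradient), and those produce no buffer in the compiled module, so the VM returns fewer values than the FX graph lists. anticipated_return_value records which positions carry real values, and the loop above walks the VM results with a separate index. A stand-alone sketch of that bookkeeping, with plain lists standing in for the VM return list (the names here are illustrative):

def map_vm_returns(anticipated_return_value, vm_returns):
    # anticipated_return_value[i] is True when graph output i has a real value,
    # False when the graph returned a torch.none placeholder at that position.
    user_returns = [None] * len(anticipated_return_value)
    vm_idx = 0  # vm_returns only contains the real (non-None) outputs
    for i, has_value in enumerate(anticipated_return_value):
        if has_value:
            user_returns[i] = vm_returns[vm_idx]
            vm_idx += 1
    return user_returns

# e.g. a backward graph returning (grad_w, None, grad_b):
assert map_vm_returns([True, False, True], ["grad_w", "grad_b"]) == ["grad_w", None, "grad_b"]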


