Skip to content

Commit

Permalink
Cleanup of APLP code
Browse files Browse the repository at this point in the history
This PR address comments from the last commit
regarding reducing the overhead of instantiating a pool
of works at every iteration.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Nov 14, 2024
1 parent 327b84b commit 461202c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
11 changes: 3 additions & 8 deletions iree/turbine/kernel/wave/scheduling/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sympy
import math
from functools import partial
from ..utils import safe_subs
import multiprocessing as mp

T = index_symbol("$INITIATION_INTERVAL")
Expand Down Expand Up @@ -202,9 +203,7 @@ def all_pairs_longest_paths_symbolic(


def all_pairs_longest_paths(
graph: fx.Graph,
edges: list[Edge],
T: int,
graph: fx.Graph, edges: list[Edge], T: int, pool: mp.Pool
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
"""
For each node in the graph, compute the longest path to all other nodes.
Expand All @@ -228,14 +227,11 @@ def all_pairs_longest_paths(
D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for k in range(N):
func = partial(all_pairs_longest_path_parallel, N, D, k)
results = pool.map(func, range(N))
for result in results:
D[result[0]] = result[1]
pool.close()
pool.join()

# Convert from index to node based representation.
G: dict[tuple[fx.Node, fx.Node], int] = {}
Expand All @@ -257,8 +253,7 @@ def evaluate_all_pairs_longest_paths(
"""
D_static = dict(D)
for key in D_static:
if isinstance(D_static[key], sympy.Expr):
D_static[key] = D_static[key].subs(T, initiation_interval)
D_static[key] = safe_subs(D_static[key], [(T, initiation_interval)])
# Remove the negative infinity values and edges to self.
for k in list(D_static.keys()):
if math.isinf(D_static[k]) or k[0] == k[1]:
Expand Down
6 changes: 5 additions & 1 deletion iree/turbine/kernel/wave/scheduling/modulo_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Callable
import numpy as np
import math
import multiprocessing as mp

logger = get_logger("turbine.wave.modulo_scheduling")

Expand Down Expand Up @@ -110,10 +111,11 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
# TODO: Come up with a better heuristic on an upper bound for the initiation interval.
T_max_range = 3 * T0
success = False
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for T in range(T0, T0 + T_max_range):
logger.debug(f"Trying initiation interval: {T}.")
self.RT = np.zeros((T, len(self.resources)))
self.e_star = all_pairs_longest_paths(self.graph, self.edges, T)
self.e_star = all_pairs_longest_paths(self.graph, self.edges, T, pool)
logger.debug(f"All Pairs Longest Paths: {self.e_star}.")
self.schedule: dict[fx.Node, int] = {}
for _, scc in topological_sort(sccs).items():
Expand Down Expand Up @@ -148,6 +150,8 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
break
else:
raise Exception("Failed to schedule the graph.")
pool.close()
pool.join()

self._initiation_interval = T
return self.schedule, success
Expand Down
4 changes: 3 additions & 1 deletion tests/kernel/wave/scheduling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
import torch.fx as fx
import numpy as np
import multiprocessing as mp
from iree.turbine.kernel.wave.visualization import visualize_graph
from iree.turbine.kernel.wave.scheduling.graph_utils import (
find_strongly_connected_components,
Expand Down Expand Up @@ -180,7 +181,8 @@ def testGraphUtils(self):
def testAPLP(self):
graph, weighted_edges, nodes = self.create_weighted_graph()
T = 4
D3 = all_pairs_longest_paths(graph, weighted_edges, T)
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
D3 = all_pairs_longest_paths(graph, weighted_edges, T, pool)
assert D3[(nodes["a"], nodes["b"])] == 2
assert D3[(nodes["a"], nodes["c"])] == 3
assert D3[(nodes["a"], nodes["d"])] == 4
Expand Down

0 comments on commit 461202c

Please sign in to comment.