Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cuda backend for KLU #35

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ dev = [
# "meep",
]

# Optional GPU extra. The import name stays `cupy`; the `cupy-cuda12x`
# distribution is the prebuilt CUDA 12 wheel (matching `jax[cuda12]`), whereas
# the plain `cupy` distribution on PyPI has to be compiled from source.
cuda = [
  "jax[cuda12]",
  "cupy-cuda12x",
]

[tool.setuptools.packages.find]
where = ["."]
include = ["sax", "sax.nn", "sax.backends"]
Expand Down
25 changes: 25 additions & 0 deletions sax/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@
"better performance during circuit evaluation!"
)

# Optional CUDA backend: it is registered, and promoted to the default, only
# when the `.cuda` module can be imported.
try:
    from .cuda import (
        analyze_circuit_cuda,
        analyze_instances_cuda,
        evaluate_circuit_cuda,
    )
except ImportError:
    # Without the CUDA backend, prefer "klu" when it has been registered and
    # fall back to "fg" otherwise.
    default_backend = "klu" if "klu" in circuit_backends else "fg"
    circuit_backends["default"] = circuit_backends[default_backend]
    warnings.warn(
        "cupy not found. Please install cupy for "
        "better performance during circuit evaluation!"
    )
else:
    # The same (analyze_instances, analyze_circuit, evaluate_circuit) triple
    # is stored under both keys.
    circuit_backends["cuda"] = circuit_backends["default"] = (
        analyze_instances_cuda,
        analyze_circuit_cuda,
        evaluate_circuit_cuda,
    )


def analyze_instances(
instances: Dict[str, Component],
Expand Down
200 changes: 200 additions & 0 deletions sax/backends/cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
""" SAX CUDA Backend """

from __future__ import annotations

from typing import Any, Dict

import cupy as cp
import cupyx
import cupyx.scipy.sparse.linalg
import jax.numpy as jnp
from natsort import natsorted

from ..netlist import Component
from ..saxtypes import Model, SCoo, SDense, SType, scoo


def scoo_cupy(S):
    """Normalize ``S`` with :func:`scoo` and convert its arrays to CuPy.

    Returns:
        A ``(Si, Sj, Sx, ports_map)`` tuple in which the row indices, column
        indices and values have been passed through ``cp.asarray``; the port
        map returned by ``scoo`` is forwarded unchanged.
    """
    row_idx, col_idx, values, ports_map = scoo(S)
    return (*map(cp.asarray, (row_idx, col_idx, values)), ports_map)


def solve_cuda(Ai, Aj, Ax, B):
    """Solve a batch of sparse systems ``A X = B`` with CuPy.

    Every matrix in the batch shares the sparsity pattern ``(Ai, Aj)``; only
    the non-zero values differ from one batch item to the next.

    Args:
        Ai (array): Row indices of the non-zero entries.
        Aj (array): Column indices of the non-zero entries.
        Ax (array): Non-zero values. It is iterated along its first axis and
            each item provides the values of one matrix.
        B (array): Right-hand side, reused for every matrix in the batch.

    Returns:
        array: The per-matrix solutions combined with ``cp.asarray``, in the
        same order as the items of ``Ax``.
    """

    def _solve_single(values):
        # Build in COO form, then switch to CSR before calling the solver.
        csr = cupyx.scipy.sparse.coo_matrix((values, (Ai, Aj))).tocsr()
        return cupyx.scipy.sparse.linalg.spsolve(csr, B)

    return cp.asarray([_solve_single(values) for values in Ax])


def coo_mul_vec(Si, Sj, Sx, x):
    """Batched product of square COO matrices with dense right-hand sides.

    All matrices share the sparsity pattern ``(Si, Sj)``. The batch is walked
    in lockstep over the first axis of ``Sx`` and ``x``.

    Args:
        Si (array): Row indices of the non-zero entries.
        Sj (array): Column indices of the non-zero entries.
        Sx (array): Non-zero values; each item along the first axis holds the
            values of one matrix.
        x (array): Dense operands; each item along the first axis is
            multiplied by the matching matrix.

    Returns:
        array: The per-item products combined with ``cp.asarray``.
    """
    products = []
    for values, operand in zip(Sx, x):
        # The matrix is treated as square with as many rows/columns as the
        # operand has rows. Passing the shape explicitly matters: when it is
        # left out it is inferred from the largest index present, so a
        # trailing row or column without any stored entry would silently
        # shrink the matrix (and therefore the result).
        size = operand.shape[0]
        matrix = cupyx.scipy.sparse.coo_matrix(
            (values, (Si, Sj)), shape=(size, size)
        )
        products.append(matrix.dot(operand))

    return cp.asarray(products)


def analyze_instances_cuda(
    instances: Dict[str, Component],
    models: Dict[str, Model],
) -> Dict[str, SCoo]:
    """Evaluate each distinct model once and map every instance to the result.

    Args:
        instances: Instance name to ``Component``. Values that are not
            ``Component`` objects are expanded with ``Component(**value)``.
        models: Model name to a callable; each required model is called
            without arguments.

    Returns:
        Instance name to the ``scoo_cupy`` form of its model's output.
        Instances that resolve to the same model share the same object.
    """

    def _model_key(component):
        # A string stored under info["model"] takes precedence over the
        # component name.
        info = component.info
        if info and "model" in info and isinstance(info["model"], str):
            return str(info["model"])
        return str(component.component)

    components = {
        name: spec if isinstance(spec, Component) else Component(**spec)
        for name, spec in instances.items()
    }
    key_of = {name: _model_key(component) for name, component in components.items()}
    evaluated = {key: scoo_cupy(models[key]()) for key in set(key_of.values())}
    return {name: evaluated[key] for name, key in key_of.items()}


def analyze_circuit_cuda(
    analyzed_instances: Dict[str, SCoo],
    connections: Dict[str, str],
    ports: Dict[str, str],
) -> Any:
    """Precompute the index structure needed to evaluate a circuit with CuPy.

    Args:
        analyzed_instances: Instance name to its analyzed S-matrix (as returned
            by ``analyze_instances_cuda``).
        connections: ``"instance,port"`` to ``"instance,port"`` links. The
            reversed direction of every link is added internally.
        ports: External port name to ``"instance,port"``.

    Returns:
        A tuple ``(n_col, mask, Si, Sj, Cext, Cexti, Cextj, I_CSi, I_CSj,
        dummy_pms, port_names)`` that ``evaluate_circuit_cuda`` unpacks.
    """
    # Make the connection table symmetric: original links first, then the
    # reversed ones (a reversed link overrides an existing key).
    symmetric = dict(connections)
    symmetric.update({dst: src for src, dst in connections.items()})
    target_to_port = {target: port for port, target in ports.items()}
    port_map = {port: pos for pos, port in enumerate(ports)}

    # Stack all instances along one global port axis; `offset` is the global
    # index of the first port of the instance being processed.
    offset = 0
    row_blocks, col_blocks = [], []
    instance_ports = {}
    for name, instance in analyzed_instances.items():
        rows, cols, _, ports_map = scoo_cupy(instance)
        row_blocks.append(cp.asarray(rows + offset))
        col_blocks.append(cp.asarray(cols + offset))
        for port, local in ports_map.items():
            instance_ports[f"{name},{port}"] = local + offset
        offset += len(ports_map)

    n_col, n_rhs = offset, len(port_map)
    Si = cp.concatenate(row_blocks, -1)
    Sj = cp.concatenate(col_blocks, -1)

    # Internal connections expressed as global index pairs.
    Cmap = {
        int(instance_ports[src]): int(instance_ports[dst])
        for src, dst in symmetric.items()
    }
    Ci = cp.array(list(Cmap.keys()), dtype=cp.int32)
    Cj = cp.array(list(Cmap.values()), dtype=cp.int32)

    # External ports: a one in Cext at (global instance port, external port).
    Cextmap = {
        int(instance_ports[target]): int(port_map[port])
        for target, port in target_to_port.items()
    }
    Cexti = cp.asarray(list(Cextmap.keys()))
    Cextj = cp.asarray(list(Cextmap.values()))
    Cext = cp.zeros((n_col, n_rhs), dtype=complex)
    Cext[Cexti, Cextj] = 1.0

    # Row indices of C @ S: for every S entry whose row index equals some Cj,
    # take the matching Ci.
    pair_mask = Cj[None, :] == Si[:, None]
    CSi = cp.broadcast_to(Ci[None, :], pair_mask.shape)[pair_mask]

    # Column indices of C @ S: the S entries whose row index occurs in Cj.
    # This mask is returned because the evaluation step uses it to select the
    # matching values.
    mask = (Cj[:, None] == Si[None, :]).any(0)
    CSj = Sj[mask]

    # Append the diagonal of the identity to obtain the pattern of I - C @ S.
    diagonal = cp.arange(n_col)
    I_CSi = cp.concatenate([CSi, diagonal], -1)
    I_CSj = cp.concatenate([CSj, diagonal], -1)

    dummy_pms = tuple((name, value[1]) for name, value in analyzed_instances.items())
    return (
        n_col,
        mask,
        Si,
        Sj,
        Cext,
        Cexti,
        Cextj,
        I_CSi,
        I_CSj,
        dummy_pms,
        tuple(port_map),
    )


def evaluate_circuit_cuda(analyzed: Any, instances: Dict[str, SType]) -> SDense:
    """Evaluate a circuit from the structure built by ``analyze_circuit_cuda``.

    Args:
        analyzed: The tuple returned by ``analyze_circuit_cuda``.
        instances: Instance name to its evaluated S-matrix (any form accepted
            by ``scoo_cupy``).

    Returns:
        ``(S, port_map)`` where ``S`` is converted with ``jnp.asarray`` and has
        shape ``(*batch_shape, n, n)``, and ``port_map`` maps each external
        port name to its index.
    """
    (
        n_col,
        mask,
        Si,
        Sj,
        Cext,
        Cexti,
        Cextj,
        I_CSi,
        I_CSj,
        dummy_pms,
        port_map,
    ) = analyzed

    # Collect the values of every instance, in the order fixed at analysis
    # time, and keep the batch shape with the most leading dimensions.
    Sx = []
    batch_shape = ()
    for name, _ in dummy_pms:
        sx = scoo_cupy(instances[name])[2]
        Sx.append(sx)
        if len(sx.shape[:-1]) > len(batch_shape):
            batch_shape = sx.shape[:-1]

    Sx = cp.concatenate(
        [cp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1
    )
    # Values of I - C @ S: the negated selected S values followed by the ones
    # of the identity diagonal.
    CSx = Sx[..., mask]
    Ix = cp.ones((*batch_shape, n_col))
    I_CSx = cp.concatenate([-CSx, Ix], -1)

    # Flatten all batch dimensions into one leading axis for the solvers.
    Sx = Sx.reshape(-1, Sx.shape[-1])  # n_lhs x N
    I_CSx = I_CSx.reshape(-1, I_CSx.shape[-1])  # n_lhs x M
    inv_I_CS_Cext = solve_cuda(I_CSi, I_CSj, I_CSx, Cext)
    S_inv_I_CS_Cext = coo_mul_vec(Si, Sj, Sx, inv_I_CS_Cext)

    # Keep only the rows and columns that belong to external ports.
    CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj]

    _, n, _ = CextT_S_inv_I_CS_Cext.shape
    S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n)

    return jnp.asarray(S), {p: i for i, p in enumerate(port_map)}
118 changes: 118 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest

import sax
import jax.numpy as jnp


@pytest.mark.parametrize("backend", ["cuda", "default", "klu", "fg"])
def test_backend(backend):
    """Check that ``backend`` reproduces the KLU result on a small circuit.

    Optional backends are only registered when their dependency is installed,
    so the test is skipped (rather than failing on a missing registry key or
    module attribute) when either the backend under test or the KLU reference
    is unavailable.
    """
    available = sax.backends.circuit_backends
    if backend not in available:
        pytest.skip(f"backend {backend!r} is not available")
    if "klu" not in available:
        pytest.skip("klu reference backend is not available")

    instances = {
        "lft": {"component": "coupler"},
        "top": {"component": "wg"},
        "rgt": {"component": "mmi"},
    }
    connections = {"lft,out0": "rgt,in0", "lft,out1": "top,in0", "top,out0": "rgt,in1"}
    ports = {"in0": "lft,in0", "out0": "rgt,out0"}
    models = {
        "wg": lambda: {
            ("in0", "out0"): -0.99477 - 0.10211j,
            ("out0", "in0"): -0.99477 - 0.10211j,
        },
        "mmi": lambda: {
            ("in0", "out0"): 0.7071067811865476,
            ("in0", "out1"): 0.7071067811865476j,
            ("in1", "out0"): 0.7071067811865476j,
            ("in1", "out1"): 0.7071067811865476,
            ("out0", "in0"): 0.7071067811865476,
            ("out1", "in0"): 0.7071067811865476j,
            ("out0", "in1"): 0.7071067811865476j,
            ("out1", "in1"): 0.7071067811865476,
        },
        "coupler": lambda: (
            jnp.array(
                [
                    [
                        5.19688622e-06 - 1.19777138e-05j,
                        6.30595625e-16 - 1.48061189e-17j,
                        -3.38542541e-01 - 6.15711852e-01j,
                        5.80662654e-03 - 1.11068866e-02j,
                        -3.38542542e-01 - 6.15711852e-01j,
                        -5.80662660e-03 + 1.11068866e-02j,
                    ],
                    [
                        8.59445189e-16 - 8.29783014e-16j,
                        -2.08640825e-06 + 8.17315497e-06j,
                        2.03847666e-03 - 2.10649131e-03j,
                        5.30509661e-01 + 4.62504708e-01j,
                        -2.03847666e-03 + 2.10649129e-03j,
                        5.30509662e-01 + 4.62504708e-01j,
                    ],
                    [
                        -3.38542541e-01 - 6.15711852e-01j,
                        2.03847660e-03 - 2.10649129e-03j,
                        7.60088070e-06 + 9.07340423e-07j,
                        2.79292426e-09 + 2.79093547e-07j,
                        5.07842364e-06 + 2.16385350e-06j,
                        -6.84244232e-08 - 5.00486817e-07j,
                    ],
                    [
                        5.80662707e-03 - 1.11068869e-02j,
                        5.30509661e-01 + 4.62504708e-01j,
                        2.79291895e-09 + 2.79093540e-07j,
                        -4.55645798e-06 + 1.50570403e-06j,
                        6.84244128e-08 + 5.00486817e-07j,
                        -3.55812153e-06 + 4.59781091e-07j,
                    ],
                    [
                        -3.38542541e-01 - 6.15711852e-01j,
                        -2.03847672e-03 + 2.10649131e-03j,
                        5.07842364e-06 + 2.16385349e-06j,
                        6.84244230e-08 + 5.00486816e-07j,
                        7.60088070e-06 + 9.07340425e-07j,
                        -2.79292467e-09 - 2.79093547e-07j,
                    ],
                    [
                        -5.80662607e-03 + 1.11068863e-02j,
                        5.30509662e-01 + 4.62504708e-01j,
                        -6.84244296e-08 - 5.00486825e-07j,
                        -3.55812153e-06 + 4.59781093e-07j,
                        -2.79293217e-09 - 2.79093547e-07j,
                        -4.55645798e-06 + 1.50570403e-06j,
                    ],
                ]
            ),
            {"in0": 0, "out0": 2, "out1": 4},
        ),
    }

    def _evaluate(analyze_instances, analyze_circuit, evaluate_circuit):
        # Run the three backend stages and normalize the result to an sdict.
        analyzed_instances = analyze_instances(instances, models)
        analyzed_circuit = analyze_circuit(analyzed_instances, connections, ports)
        evaluated = {k: models[v["component"]]() for k, v in instances.items()}
        return sax.sdict(evaluate_circuit(analyzed_circuit, evaluated))

    sdict_backend = _evaluate(*available[backend])
    sdict_klu = _evaluate(
        sax.backends.analyze_instances_klu,
        sax.backends.analyze_circuit_klu,
        sax.backends.evaluate_circuit_klu,
    )

    # Compare to klu backend as source of truth
    for k in sdict_klu:
        val_klu = sdict_klu[k]
        val_backend = sdict_backend[k]
        assert abs(val_klu - val_backend) < 1e-5
9 changes: 6 additions & 3 deletions tests/test_nbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,28 @@
NBS_DIR = os.path.join(TEST_DIR, "nbs")
NBS_FAIL_DIR = os.path.join(NBS_DIR, "failed")


def get_kernel():
    """Return the kernel name derived from the ``CONDA_DEFAULT_ENV`` variable.

    The value of ``CONDA_DEFAULT_ENV`` is returned as-is, except that an unset
    variable or the value ``"base"`` yields ``"python3"``.
    """
    env_name = os.environ.get("CONDA_DEFAULT_ENV", "base")
    return "python3" if env_name == "base" else env_name


# Start every run with a clean directory for failed notebooks. The removal is
# best-effort (errors are ignored), so the directory may still be present
# afterwards; creating it with exist_ok=True keeps that case from raising
# FileExistsError while this module is being imported.
shutil.rmtree(NBS_FAIL_DIR, ignore_errors=True)
os.makedirs(NBS_FAIL_DIR, exist_ok=True)


def _find_notebooks(*dir):
    """Yield the paths of all notebooks below ``TEST_DIR/<dir...>``.

    The directory tree is walked recursively. Only files ending in ``.ipynb``
    are yielded, and any file with ``checkpoint`` in its name is left out.
    """
    base = os.path.abspath(os.path.join(TEST_DIR, *dir))
    for root, _, files in os.walk(base):
        yield from (
            os.path.join(root, file)
            for file in files
            if file.endswith(".ipynb") and "checkpoint" not in file
        )

@pytest.mark.parametrize('path', sorted(_find_notebooks('nbs')))

@pytest.mark.parametrize("path", sorted(_find_notebooks("nbs")))
def test_nbs(path):
fn = os.path.basename(path)
nb = load_notebook_node(path)
Expand All @@ -38,4 +42,3 @@ def test_nbs(path):
output_path=None,
)
raise_for_execution_errors(nb, os.path.join(NBS_FAIL_DIR, fn))

Loading
Loading