Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MINOR] Add functions and UnitTests for the Estimator LayeredGraph #1945

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 328 additions & 6 deletions src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* This estimator implements an approach based on a so-called layered graph,
Expand All @@ -43,7 +44,7 @@
*/
public class EstimatorLayeredGraph extends SparsityEstimator {

private static final int ROUNDS = 32;
private static final int ROUNDS = 512;
private final int _rounds;

public EstimatorLayeredGraph() {
Expand All @@ -57,21 +58,47 @@ public EstimatorLayeredGraph(int rounds) {
@Override
public DataCharacteristics estim(MMNode root) {
List<MatrixBlock> leafs = getMatrices(root, new ArrayList<>());
long nnz = new LayeredGraph(leafs, _rounds).estimateNnz();
List<OpCode> ops = getOps(root, new ArrayList<>());
List<LayeredGraph> LGs = new ArrayList<>();
LayeredGraph ret = traverse(root);
long nnz = ret.estimateNnz();
return root.setDataCharacteristics(new MatrixCharacteristics(
leafs.get(0).getNumRows(), leafs.get(leafs.size()-1).getNumColumns(), nnz));
ret._nodes.get(0).length, ret._nodes.get(ret._nodes.size() - 1).length, nnz));
}

/**
 * Recursively converts an operator tree into one layered graph.
 *
 * A node that carries a matrix is turned directly into a graph of that
 * matrix. Any other node is evaluated by first converting its children and
 * then combining them via {@code estimInternal} with the node's op code.
 * A missing right child is passed on as {@code null}, so unary operations
 * (e.g., TRANS, DIAG) can be evaluated as well.
 *
 * NOTE(review): this assumes only leaf nodes return non-null from
 * {@code getData()} - confirm against MMNode.
 *
 * @param node current node of the operator tree (may be null)
 * @return the layered graph of the sub-tree, or {@code null} if
 *         {@code node} is null or has neither data nor a left child
 */
public LayeredGraph traverse(MMNode node) {
	if( node == null )
		return null;

	// node backed by a matrix: no recursion required
	if( node.getData() != null )
		return new LayeredGraph(node.getData(), _rounds);

	// nothing to build the graph from
	if( node.getLeft() == null )
		return null;

	LayeredGraph left = traverse(node.getLeft());
	// stays null for unary operators, which ignore the second operand
	LayeredGraph right = traverse(node.getRight());

	return estimInternal(left, right, node.getOp());
}

/**
 * Estimates the output sparsity of a binary operation on two matrices.
 *
 * Matrix multiplication is delegated to {@code estim(m1, m2)}; every other
 * op code is evaluated on single-matrix layered graphs via
 * {@code estimInternal}, which throws for unsupported op codes.
 *
 * @param m1 left-hand-side matrix
 * @param m2 right-hand-side matrix
 * @param op operation to estimate
 * @return estimated sparsity of the output
 */
@Override
public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
	if( op == OpCode.MM )
		return estim(m1, m2);
	LayeredGraph lg1 = new LayeredGraph(m1, _rounds);
	LayeredGraph lg2 = new LayeredGraph(m2, _rounds);
	LayeredGraph output = estimInternal(lg1, lg2, op);
	// output dimensions are the sizes of the first and last graph layer
	return OptimizerUtils.getSparsity(
		output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz());
}

/**
 * Estimates the output sparsity of a unary operation on a single matrix.
 *
 * The matrix is converted into a layered graph and evaluated through
 * {@code estimInternal} with a {@code null} second operand; unsupported
 * op codes make {@code estimInternal} throw.
 *
 * @param m  input matrix
 * @param op unary operation to estimate
 * @return estimated sparsity of the output
 */
@Override
public double estim(MatrixBlock m, OpCode op) {
	LayeredGraph lg1 = new LayeredGraph(m, _rounds);
	LayeredGraph output = estimInternal(lg1, null, op);
	// output dimensions are the sizes of the first and last graph layer
	return OptimizerUtils.getSparsity(
		output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz());
}

@Override
Expand All @@ -80,6 +107,23 @@ public double estim(MatrixBlock m1, MatrixBlock m2) {
return OptimizerUtils.getSparsity(
m1.getNumRows(), m2.getNumColumns(), graph.estimateNnz());
}

/**
 * Dispatches an op code to the matching layered-graph operation.
 *
 * @param lg1 first operand (receiver of the graph operation)
 * @param lg2 second operand; not used by TRANS and DIAG
 * @param op  operation to apply
 * @return the layered graph representing the operation's output
 * @throws NotImplementedException for every op code without a graph
 *         operation (among others NEQZERO, EQZERO and RESHAPE)
 */
private static LayeredGraph estimInternal(LayeredGraph lg1, LayeredGraph lg2, OpCode op) {
	final LayeredGraph result;
	switch(op) {
		// unary operations
		case TRANS:
			result = lg1.transpose();
			break;
		case DIAG:
			result = lg1.diag();
			break;
		// binary operations
		case MM:
			result = lg1.matMult(lg2);
			break;
		case MULT:
			result = lg1.and(lg2);
			break;
		case PLUS:
			result = lg1.or(lg2);
			break;
		case RBIND:
			result = lg1.rbind(lg2);
			break;
		case CBIND:
			result = lg1.cbind(lg2);
			break;
		default:
			throw new NotImplementedException();
	}
	return result;
}

private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
//NOTE: this extraction is only correct and efficient for chains, no DAGs
Expand All @@ -92,6 +136,18 @@ private List<MatrixBlock> getMatrices(MMNode node, List<MatrixBlock> leafs) {
return leafs;
}

/**
 * Collects the op codes of all non-leaf nodes in post-order
 * (left sub-tree, right sub-tree, then the node itself).
 *
 * NOTE: this extraction is only correct and efficient for chains, no DAGs
 *
 * @param node current node of the operator tree
 * @param ops  accumulator the op codes are appended to
 * @return the same list instance that was passed in as {@code ops}
 */
private List<OpCode> getOps(MMNode node, List<OpCode> ops) {
	if( !node.isLeaf() ) {
		getOps(node.getLeft(), ops);
		getOps(node.getRight(), ops);
		ops.add(node.getOp());
	}
	return ops;
}

public static class LayeredGraph {
private final List<Node[]> _nodes; //nodes partitioned by graph level
private final int _rounds; //length of propagated r-vectors
Expand All @@ -101,6 +157,12 @@ public LayeredGraph(List<MatrixBlock> chain, int r) {
_rounds = r;
chain.forEach(i -> buildNext(i));
}

/**
 * Creates a layered graph for a single matrix by starting from an empty
 * layer list and appending the matrix through {@code buildNext}.
 *
 * @param m matrix to represent
 * @param r length of the propagated r-vectors
 */
public LayeredGraph(MatrixBlock m, int r) {
	_rounds = r;
	_nodes = new ArrayList<>();
	buildNext(m);
}

public void buildNext(MatrixBlock mb) {
if( mb.isEmpty() )
Expand Down Expand Up @@ -168,7 +230,267 @@ private static double calcNNZ(double[] inpvec, int rounds) {
return (inpvec != null && inpvec.length > 0) ?
(rounds - 1) / Arrays.stream(inpvec).sum() : 0;
}


/**
 * Row-binds this graph (on top) with {@code lg} (below).
 *
 * Only the first two layers of both graphs are considered. The result is
 * built from fresh nodes exclusively, so neither this graph nor {@code lg}
 * is modified and no node is shared with the result.
 *
 * @param lg graph appended below; expected to have the same number of
 *           column nodes as this graph
 * @return a new two-layer graph with the rows of both inputs
 */
public LayeredGraph rbind(LayeredGraph lg) {
	Node[] topRows = _nodes.get(0);
	Node[] topCols = _nodes.get(1);
	Node[] botRows = lg._nodes.get(0);

	Node[] rows = new Node[topRows.length + botRows.length];
	for(int i = 0; i < rows.length; i++)
		rows[i] = new Node();

	Node[] columns = new Node[topCols.length];
	for(int j = 0; j < columns.length; j++) {
		columns[j] = new Node();

		// replay the edges of this graph onto the fresh nodes
		List<Node> topEdges = topCols[j].getInput();
		for(int i = 0; i < topRows.length; i++) {
			if(topEdges.contains(topRows[i]))
				columns[j].addInput(rows[i]);
		}

		// replay the edges of lg, shifted below the rows of this graph
		if(botRows.length > 0) {
			List<Node> botEdges = lg._nodes.get(1)[j].getInput();
			for(int i = 0; i < botRows.length; i++) {
				if(botEdges.contains(botRows[i]))
					columns[j].addInput(rows[topRows.length + i]);
			}
		}
	}

	LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
	ret._nodes.add(rows);
	ret._nodes.add(columns);
	return ret;
}

/**
 * Column-binds this graph (left) with {@code lg} (right).
 *
 * Only the first two layers of both graphs are considered. The row nodes
 * and the left column nodes of the result are the very same node instances
 * as in this graph (shallow array copies); only the appended right column
 * nodes are newly created, with their edges replayed from {@code lg}.
 *
 * @param lg graph appended to the right
 * @return a new two-layer graph with the columns of both inputs
 */
public LayeredGraph cbind(LayeredGraph lg) {
	LayeredGraph combined = new LayeredGraph(List.of(), _rounds);
	Node[] leftCols = _nodes.get(1);
	Node[] rightCols = lg._nodes.get(1);
	int offset = leftCols.length;

	Node[] rows = Arrays.copyOf(_nodes.get(0), _nodes.get(0).length);
	// left part: existing column nodes, right part: filled with new nodes below
	Node[] columns = Arrays.copyOf(leftCols, offset + rightCols.length);
	for(int k = offset; k < columns.length; k++)
		columns[k] = new Node();

	Node[] rightRows = lg._nodes.get(0);
	for(int j = 0; j < rightCols.length; j++) {
		List<Node> edges = rightCols[j].getInput();
		for(int i = 0; i < rows.length; i++) {
			if(edges.contains(rightRows[i]))
				columns[offset + j].addInput(rows[i]);
		}
	}

	combined._nodes.add(rows);
	combined._nodes.add(columns);
	return combined;
}

/**
 * Builds the graph of the matrix product of this graph and {@code lg}.
 *
 * Both graphs are materialized as matrix chains via
 * {@code toMatrixBlockList()}, and a new graph is constructed from the
 * concatenated chain (this graph's matrices first).
 *
 * @param lg right-hand-side graph
 * @return a new layered graph over the concatenated chain
 */
public LayeredGraph matMult(LayeredGraph lg) {
	List<MatrixBlock> chain = new ArrayList<>(this.toMatrixBlockList());
	chain.addAll(lg.toMatrixBlockList());
	return new LayeredGraph(chain, _rounds);
}

/**
 * Combines this graph and {@code lg} layer by layer with a logical OR:
 * the result has an edge between two nodes of adjacent layers if at least
 * one of the two inputs has the edge at the same positions.
 *
 * The layer sizes of the result are taken from this graph; all nodes of the
 * result are newly created.
 *
 * @param lg second operand, indexed with the same layer and node positions
 * @return a new layered graph holding the union of the edges
 */
public LayeredGraph or(LayeredGraph lg) {
	LayeredGraph result = new LayeredGraph(List.of(), _rounds);
	Node[] first = new Node[_nodes.get(0).length];
	Arrays.setAll(first, k -> new Node());
	result._nodes.add(first);

	for(int lower = 0, upper = 1; upper < _nodes.size(); lower++, upper++) {
		Node[] srcLower = _nodes.get(lower);
		Node[] srcUpper = _nodes.get(upper);
		Node[] dstLower = result._nodes.get(lower);
		Node[] dstUpper = new Node[srcUpper.length];
		Arrays.setAll(dstUpper, k -> new Node());

		for(int i = 0; i < srcLower.length; i++) {
			for(int j = 0; j < srcUpper.length; j++) {
				List<Node> mine = srcUpper[j].getInput();
				List<Node> theirs = lg._nodes.get(upper)[j].getInput();
				boolean linked = mine.contains(srcLower[i])
					|| theirs.contains(lg._nodes.get(lower)[i]);
				if(linked)
					dstUpper[j].addInput(dstLower[i]);
			}
		}
		result._nodes.add(dstUpper);
	}
	return result;
}

/**
 * Combines this graph and {@code lg} layer by layer with a logical AND:
 * the result has an edge between two nodes of adjacent layers only if both
 * inputs have the edge at the same positions.
 *
 * The layer sizes of the result are taken from this graph; all nodes of the
 * result are newly created.
 *
 * @param lg second operand, indexed with the same layer and node positions
 * @return a new layered graph holding the intersection of the edges
 */
public LayeredGraph and(LayeredGraph lg) {
	LayeredGraph result = new LayeredGraph(List.of(), _rounds);
	Node[] first = new Node[_nodes.get(0).length];
	Arrays.setAll(first, k -> new Node());
	result._nodes.add(first);

	for(int lower = 0, upper = 1; upper < _nodes.size(); lower++, upper++) {
		Node[] srcLower = _nodes.get(lower);
		Node[] srcUpper = _nodes.get(upper);
		Node[] dstLower = result._nodes.get(lower);
		Node[] dstUpper = new Node[srcUpper.length];
		Arrays.setAll(dstUpper, k -> new Node());

		for(int i = 0; i < srcLower.length; i++) {
			for(int j = 0; j < srcUpper.length; j++) {
				List<Node> mine = srcUpper[j].getInput();
				List<Node> theirs = lg._nodes.get(upper)[j].getInput();
				boolean linked = mine.contains(srcLower[i])
					&& theirs.contains(lg._nodes.get(lower)[i]);
				if(linked)
					dstUpper[j].addInput(dstLower[i]);
			}
		}
		result._nodes.add(dstUpper);
	}
	return result;
}

/**
 * Builds the transposed graph: the layers appear in reverse order and
 * every edge between two adjacent layers is reversed.
 *
 * All nodes of the result are newly created; this graph is only read.
 *
 * @return a new layered graph with reversed layers and edges
 */
public LayeredGraph transpose() {
	LayeredGraph flipped = new LayeredGraph(List.of(), _rounds);

	// the last layer of this graph becomes the first layer of the result
	Node[] prev = new Node[_nodes.get(_nodes.size() - 1).length];
	Arrays.setAll(prev, k -> new Node());
	flipped._nodes.add(prev);

	for(int level = _nodes.size() - 1; level > 0; level--) {
		Node[] upperOld = _nodes.get(level);
		Node[] lowerOld = _nodes.get(level - 1);
		Node[] next = new Node[lowerOld.length];
		Arrays.setAll(next, k -> new Node());

		for(int j = 0; j < upperOld.length; j++) {
			List<Node> edges = upperOld[j].getInput();
			for(int i = 0; i < lowerOld.length; i++) {
				// old edge lowerOld[i] -> upperOld[j] becomes prev[j] -> next[i]
				if(edges.contains(lowerOld[i]))
					next[i].addInput(prev[j]);
			}
		}
		flipped._nodes.add(next);
		prev = next;
	}
	return flipped;
}

/**
 * Builds the graph of the diag operation on the first two layers.
 *
 * <ul>
 * <li>Vector input (exactly one column node, or otherwise exactly one row
 * node) of length n: the result is an n x n graph that has the edge
 * (k,k) whenever the k-th vector entry has an edge.</li>
 * <li>Any other input: the result has as many row nodes as the input and a
 * single column node, which is connected to row k whenever the input
 * has the diagonal edge (k,k).</li>
 * </ul>
 *
 * All nodes of the result are newly created; this graph is only read.
 *
 * @return a new two-layer graph
 */
public LayeredGraph diag() {
	Node[] rowsOld = _nodes.get(0);
	Node[] columnsOld = _nodes.get(1);
	final Node[] rows;
	final Node[] columns;

	boolean colVector = (columnsOld.length == 1);
	boolean rowVector = !colVector && (rowsOld.length == 1);

	if(colVector || rowVector) {
		// vector -> diagonal matrix
		int n = colVector ? rowsOld.length : columnsOld.length;
		rows = new Node[n];
		columns = new Node[n];
		for(int k = 0; k < n; k++) {
			rows[k] = new Node();
			columns[k] = new Node();
		}
		for(int k = 0; k < n; k++) {
			boolean nonZero = colVector
				? columnsOld[0].getInput().contains(rowsOld[k])
				: columnsOld[k].getInput().contains(rowsOld[0]);
			if(nonZero)
				columns[k].addInput(rows[k]);
		}
	}
	else {
		// matrix -> diagonal as a single column
		rows = new Node[rowsOld.length];
		for(int k = 0; k < rows.length; k++)
			rows[k] = new Node();
		columns = new Node[] { new Node() };
		// only positions present in both dimensions can be on the diagonal
		int n = Math.min(rowsOld.length, columnsOld.length);
		for(int k = 0; k < n; k++) {
			if(columnsOld[k].getInput().contains(rowsOld[k]))
				columns[0].addInput(rows[k]);
		}
	}

	LayeredGraph ret = new LayeredGraph(List.of(), _rounds);
	ret._nodes.add(rows);
	ret._nodes.add(columns);
	return ret;
}

/**
 * Materializes the first two layers of this graph as a 0/1 matrix: cell
 * (i,j) is 1 if column node j has row node i as input, and 0 otherwise.
 *
 * @return a matrix block created from a row-major array with one row per
 *         node of layer 0 and one column per node of layer 1
 */
public MatrixBlock toMatrixBlock() {
	Node[] rowNodes = _nodes.get(0);
	Node[] colNodes = _nodes.get(1);
	int rows = rowNodes.length;
	int cols = colNodes.length;

	// row-major values, zero-initialized by the JVM
	double[] arr = new double[rows * cols];
	for(int j = 0; j < cols; j++) {
		List<Node> edges = colNodes[j].getInput();
		for(int i = 0; i < rows; i++) {
			if(edges.contains(rowNodes[i]))
				arr[i * cols + j] = 1.;
		}
	}
	return new MatrixBlock(rows, cols, arr);
}

/**
 * Materializes every pair of adjacent layers as a 0/1 matrix: for the
 * layers x and x+1, cell (i,j) is 1 if node j of layer x+1 has node i of
 * layer x as input, and 0 otherwise.
 *
 * @return one matrix block per pair of adjacent layers, in layer order;
 *         empty if the graph has fewer than two layers
 */
public List<MatrixBlock> toMatrixBlockList() {
	List<MatrixBlock> chain = new ArrayList<>();
	for(int x = 0; x < _nodes.size() - 1; x++) {
		Node[] rowNodes = _nodes.get(x);
		Node[] colNodes = _nodes.get(x + 1);
		int rows = rowNodes.length;
		int cols = colNodes.length;

		// row-major values, zero-initialized by the JVM
		double[] arr = new double[rows * cols];
		for(int j = 0; j < cols; j++) {
			List<Node> edges = colNodes[j].getInput();
			for(int i = 0; i < rows; i++) {
				if(edges.contains(rowNodes[i]))
					arr[i * cols + j] = 1.;
			}
		}
		chain.add(new MatrixBlock(rows, cols, arr));
	}
	return chain;
}

private static class Node {
private List<Node> _input = new ArrayList<>();
private double[] _rvect;
Expand Down
Loading
Loading