Fix full test (#7007)
rusty1s committed Mar 23, 2023
1 parent 5f4a21c commit c78f358
Showing 12 changed files with 84 additions and 62 deletions.
32 changes: 19 additions & 13 deletions test/nn/conv/test_gen_conv.py
@@ -40,10 +40,12 @@ def test_gen_conv(aggr):
     if is_full_test():
         t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x1, edge_index), out11)
-        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11)
-        assert torch.allclose(jit(x1, edge_index, value), out12)
-        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12)
+        assert torch.allclose(jit(x1, edge_index), out11, atol=1e-6)
+        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11,
+                              atol=1e-6)
+        assert torch.allclose(jit(x1, edge_index, value), out12, atol=1e-6)
+        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12,
+                              atol=1e-6)
 
         t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
@@ -71,10 +73,13 @@ def test_gen_conv(aggr):
     if is_full_test():
         t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), edge_index), out21)
-        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21)
-        assert torch.allclose(jit((x1, x2), edge_index, value), out22)
-        assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22)
+        assert torch.allclose(jit((x1, x2), edge_index), out21, atol=1e-6)
+        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21,
+                              atol=1e-6)
+        assert torch.allclose(jit((x1, x2), edge_index, value), out22,
+                              atol=1e-6)
+        assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22,
+                              atol=1e-6)
 
         t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
@@ -120,13 +125,14 @@ def test_gen_conv(aggr):
     if is_full_test():
         t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), edge_index, value), out1)
+        assert torch.allclose(jit((x1, x2), edge_index, value), out1,
+                              atol=1e-6)
         assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),
-                              out1)
+                              out1, atol=1e-6)
         assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),
-                              out2)
+                              out2, atol=1e-6)
 
         t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), adj1.t()), out1)
-        assert torch.allclose(jit((x1, None), adj1.t()), out2)
+        assert torch.allclose(jit((x1, x2), adj1.t()), out1, atol=1e-6)
+        assert torch.allclose(jit((x1, None), adj1.t()), out2, atol=1e-6)
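
The added atol=1e-6 exists because torch.allclose defaults to atol=1e-8, and the scripted path may reorder floating-point operations just enough to exceed that for entries near zero (the default rtol=1e-5 only covers large entries). A minimal sketch of the failure mode, with made-up values rather than real conv outputs:

    import torch

    # Two results that differ only by floating-point noise near zero:
    out_a = torch.tensor([0.0])
    out_b = torch.tensor([3e-7])

    assert not torch.allclose(out_a, out_b)         # default atol=1e-8: fails
    assert torch.allclose(out_a, out_b, atol=1e-6)  # tolerance used here: passes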
8 changes: 4 additions & 4 deletions test/nn/conv/test_graph_conv.py
@@ -80,7 +80,7 @@ def test_graph_conv():
 
         t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), adj1.t()), out21)
-        assert torch.allclose(jit((x1, x2), adj2.t()), out22)
-        assert torch.allclose(jit((x1, None), adj1.t()), out23)
-        assert torch.allclose(jit((x1, None), adj2.t()), out24)
+        assert torch.allclose(jit((x1, x2), adj1.t()), out21, atol=1e-6)
+        assert torch.allclose(jit((x1, x2), adj2.t()), out22, atol=1e-6)
+        assert torch.allclose(jit((x1, None), adj1.t()), out23, atol=1e-6)
+        assert torch.allclose(jit((x1, None), adj2.t()), out24, atol=1e-6)
8 changes: 4 additions & 4 deletions test/nn/dense/test_dense_gat_conv.py
@@ -41,14 +41,14 @@ def test_dense_gat_conv(heads, concat):
 
     dense_out = dense_conv(x, adj, mask)
 
-    assert dense_out[1, 2].abs().sum() == 0
-    dense_out = dense_out.view(6, dense_out.size(-1))[:-1]
-    assert torch.allclose(sparse_out, dense_out, atol=1e-4)
-
     if is_full_test():
         jit = torch.jit.script(dense_conv)
         assert torch.allclose(jit(x, adj, mask), dense_out)
 
+    assert dense_out[1, 2].abs().sum() == 0
+    dense_out = dense_out.view(6, dense_out.size(-1))[:-1]
+    assert torch.allclose(sparse_out, dense_out, atol=1e-4)
+
 
 def test_dense_gat_conv_with_broadcasting():
     batch_size, num_nodes, channels = 8, 3, 16
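
Moving the dense/sparse comparison below the is_full_test() block matters because those lines mutate dense_out: the jit check used to compare a (batch, num_nodes, channels) output against an already flattened and truncated tensor. A toy sketch of the mismatch, with hypothetical sizes:

    import torch

    dense_out = torch.randn(2, 3, 16)        # (batch, num_nodes, channels)
    jit_out = dense_out.clone()              # stand-in for jit(x, adj, mask)

    dense_out = dense_out.view(6, 16)[:-1]   # (5, 16): masked node dropped
    # torch.allclose(jit_out, dense_out) now raises, since shapes
    # (2, 3, 16) and (5, 16) cannot be broadcast together.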
8 changes: 4 additions & 4 deletions test/nn/dense/test_dense_gcn_conv.py
@@ -39,14 +39,14 @@ def test_dense_gcn_conv():
     dense_out = dense_conv(x, adj, mask)
     assert dense_out.size() == (2, 3, channels)
 
-    assert dense_out[1, 2].abs().sum() == 0
-    dense_out = dense_out.view(6, channels)[:-1]
-    assert torch.allclose(sparse_out, dense_out, atol=1e-4)
-
     if is_full_test():
         jit = torch.jit.script(dense_conv)
         assert torch.allclose(jit(x, adj, mask), dense_out)
 
+    assert dense_out[1, 2].abs().sum() == 0
+    dense_out = dense_out.view(6, channels)[:-1]
+    assert torch.allclose(sparse_out, dense_out, atol=1e-4)
+
 
 def test_dense_gcn_conv_with_broadcasting():
     batch_size, num_nodes, channels = 8, 3, 16
2 changes: 0 additions & 2 deletions test/nn/dense/test_linear.py
@@ -125,7 +125,6 @@ def test_lazy_hetero_linear():
 
     out = lin(x, type_vec)
     assert out.size() == (3, 32)
-    assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)'
 
 
 def test_hetero_dict_linear():
@@ -160,7 +159,6 @@ def test_lazy_hetero_dict_linear():
     assert len(out_dict) == 2
     assert out_dict['v'].size() == (3, 32)
    assert out_dict['w'].size() == (2, 32)
-    assert str(lin) == "HeteroDictLinear({'v': 16, 'w': 8}, 32, bias=True)"
 
 
 @withPackage('pyg_lib')
9 changes: 8 additions & 1 deletion test/utils/test_sparse.py
@@ -76,7 +76,14 @@ def test_to_torch_coo_tensor():
     ])
     edge_attr = torch.randn(edge_index.size(1), 8)
 
-    adj = to_torch_coo_tensor(edge_index)
+    adj = to_torch_coo_tensor(edge_index, is_coalesced=False)
+    assert adj.is_coalesced()
+    assert adj.size() == (4, 4)
+    assert adj.layout == torch.sparse_coo
+    assert torch.allclose(adj.indices(), edge_index)
+
+    adj = to_torch_coo_tensor(edge_index, is_coalesced=True)
+    assert adj.is_coalesced()
     assert adj.size() == (4, 4)
     assert adj.layout == torch.sparse_coo
     assert torch.allclose(adj.indices(), edge_index)
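
The new is_coalesced flag lets callers that already hold sorted, duplicate-free indices skip the coalesce() pass. A usage sketch on a toy graph (not part of the test file):

    import torch
    from torch_geometric.utils import to_torch_coo_tensor

    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])  # sorted by (row, col), no duplicates

    adj = to_torch_coo_tensor(edge_index, is_coalesced=True)  # no coalesce() pass
    assert adj.is_coalesced()
    assert adj.size() == (3, 3)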
7 changes: 2 additions & 5 deletions torch_geometric/compile.py
@@ -76,7 +76,8 @@ def fn(model: Callable) -> Callable:
         for key in prev_state.keys():
             setattr(torch_geometric.typing, key, False)
 
-        # Temporarily adjust the logging level of `torch.compile`:
+        # Adjust the logging level of `torch.compile`:
+        # TODO (matthias) Disable only temporarily
         prev_log_level = {
             'torch._dynamo': logging.getLogger('torch._dynamo').level,
             'torch._inductor': logging.getLogger('torch._inductor').level,
@@ -91,8 +92,4 @@ def fn(model: Callable) -> Callable:
         # Finally, run `torch.compile` to create an optimized version:
         out = torch.compile(model, *args, **kwargs)
 
-        # Restore the previous state:
-        for key, value in prev_log_level.items():
-            logging.getLogger(key).setLevel(value)
-
         return out
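
The previous levels are still recorded even though the restore step is gone, matching the TODO above. A standalone sketch of the pattern (logger names from the diff; the target level is an assumption here, not taken from the source):

    import logging

    names = ['torch._dynamo', 'torch._inductor']
    prev_log_level = {name: logging.getLogger(name).level for name in names}

    for name in names:
        logging.getLogger(name).setLevel(logging.ERROR)  # silence compile chatter

    # The removed cleanup would have undone this afterwards:
    # for name, level in prev_log_level.items():
    #     logging.getLogger(name).setLevel(level)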
12 changes: 3 additions & 9 deletions torch_geometric/nn/conv/cluster_gcn_conv.py
@@ -12,7 +12,7 @@
     spmm,
     to_edge_index,
 )
-from torch_geometric.utils.sparse import get_sparse_diag, set_sparse_value
+from torch_geometric.utils.sparse import set_sparse_value
 
 
 class ClusterGCNConv(MessagePassing):
@@ -71,6 +71,7 @@ def reset_parameters(self):
         self.lin_root.reset_parameters()
 
     def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
+        num_nodes = x.size(self.node_dim)
         edge_weight: OptTensor = None
 
         if isinstance(edge_index, SparseTensor):
@@ -94,13 +95,7 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
                                           "supported in 'gcn_norm'")
 
            if self.add_self_loops:
-                diag = get_sparse_diag(edge_index.size(0), 1.0,
-                                       edge_index.layout, edge_index.dtype,
-                                       edge_index.device)
-                edge_index = edge_index + diag
-
-                if edge_index.layout == torch.sparse_coo:
-                    edge_index = edge_index.coalesce()
+                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
 
             col_and_row, value = to_edge_index(edge_index)
             col, row = col_and_row[0], col_and_row[1]
@@ -112,7 +107,6 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
             edge_index = set_sparse_value(edge_index, edge_weight)
 
         else:
-            num_nodes = x.size(self.node_dim)
             if self.add_self_loops:
                 edge_index, _ = remove_self_loops(edge_index)
                 edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
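
The handwritten diagonal construction is replaced by a single call to add_self_loops, which (per the torch_geometric/utils/loop.py changes below) now accepts torch.sparse adjacency matrices directly. A sketch of that unified call on a toy graph:

    import torch
    from torch_geometric.utils import add_self_loops, to_torch_coo_tensor

    edge_index = torch.tensor([[0, 1],
                               [1, 2]])
    adj = to_torch_coo_tensor(edge_index, size=3)

    adj, _ = add_self_loops(adj)  # sparse in, sparse out; second slot is None
    assert adj._nnz() == 5        # two original edges plus three self-loops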
1 change: 1 addition & 0 deletions torch_geometric/nn/conv/gatv2_conv.py
@@ -251,6 +251,7 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                              size=None)
 
         alpha = self._alpha
+        assert alpha is not None
         self._alpha = None
 
         if self.concat:
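
The new assert is a type-narrowing aid rather than a runtime safeguard: self._alpha is declared Optional[Tensor], and TorchScript needs the assert to treat it as a plain Tensor afterwards. The pattern in isolation (hypothetical module, not the real GATv2Conv):

    from typing import Optional

    import torch
    from torch import Tensor

    class AlphaHolder(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self._alpha: Optional[Tensor] = None

        def forward(self, x: Tensor) -> Tensor:
            self._alpha = x.softmax(dim=-1)
            alpha = self._alpha
            assert alpha is not None  # narrows Optional[Tensor] to Tensor
            self._alpha = None
            return alpha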
13 changes: 4 additions & 9 deletions torch_geometric/nn/conv/gcn_conv.py
@@ -14,15 +14,16 @@
     SparseTensor,
     torch_sparse,
 )
+from torch_geometric.utils import add_remaining_self_loops
+from torch_geometric.utils import add_self_loops as add_self_loops_fn
 from torch_geometric.utils import (
-    add_remaining_self_loops,
     is_torch_sparse_tensor,
     scatter,
     spmm,
     to_edge_index,
 )
 from torch_geometric.utils.num_nodes import maybe_num_nodes
-from torch_geometric.utils.sparse import get_sparse_diag, set_sparse_value
+from torch_geometric.utils.sparse import set_sparse_value
@@ -70,14 +71,8 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
                                       "supported in 'gcn_norm'")
 
         adj_t = edge_index
-
         if add_self_loops:
-            diag = get_sparse_diag(adj_t.size(0), fill_value, adj_t.layout,
-                                   adj_t.dtype, adj_t.device)
-            adj_t = adj_t + diag
-
-            if adj_t.layout == torch.sparse_coo:
-                adj_t = adj_t.coalesce()
+            adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)
 
         edge_index, value = to_edge_index(adj_t)
         col, row = edge_index[0], edge_index[1]
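
gcn_norm now routes sparse adjacencies through the shared helper, imported as add_self_loops_fn so it does not shadow gcn_norm's own add_self_loops argument. End to end, the dense and sparse code paths should agree up to the tolerance the tests above use; a sketch with a hypothetical toy setup:

    import torch
    from torch_geometric.nn import GCNConv
    from torch_geometric.utils import to_torch_csc_tensor

    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    adj = to_torch_csc_tensor(edge_index, size=(4, 4))

    conv = GCNConv(8, 16)
    assert torch.allclose(conv(x, edge_index), conv(x, adj.t()), atol=1e-6)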
23 changes: 17 additions & 6 deletions torch_geometric/utils/loop.py
@@ -10,6 +10,7 @@
     is_torch_sparse_tensor,
     to_edge_index,
     to_torch_coo_tensor,
+    to_torch_csr_tensor,
 )
 
 
@@ -65,21 +66,25 @@ def remove_self_loops(
                 [3, 4]]))
     """
     size: Optional[Tuple[int, int]] = None
+    layout: Optional[int] = None
 
-    is_sparse = is_torch_sparse_tensor(edge_index)
-    if is_sparse:
+    if is_torch_sparse_tensor(edge_index):
         assert edge_attr is None
+        layout = edge_index.layout
         size = (edge_index.size(0), edge_index.size(1))
         edge_index, edge_attr = to_edge_index(edge_index)
 
     mask = edge_index[0] != edge_index[1]
     edge_index = edge_index[:, mask]
 
-    if is_sparse:
+    if layout is not None:
         assert edge_attr is not None
         edge_attr = edge_attr[mask]
-        adj = to_torch_coo_tensor(edge_index, edge_attr, size=size)
-        return adj, None
+        if str(layout) == 'torch.sparse_coo':  # str(...) for TorchScript :(
+            return to_torch_coo_tensor(edge_index, edge_attr, size, True), None
+        elif str(layout) == 'torch.sparse_csr':
+            return to_torch_csr_tensor(edge_index, edge_attr, size, True), None
+        raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")
 
     if edge_attr is None:
         return edge_index, None
@@ -220,10 +225,12 @@ def add_self_loops(
                 [1, 0, 0, 0, 1]]),
         tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000]))
     """
+    layout: Optional[int] = None
     is_sparse = is_torch_sparse_tensor(edge_index)
 
     if is_sparse:
         assert edge_attr is None
+        layout = edge_index.layout
         size = (edge_index.size(0), edge_index.size(1))
         edge_index, edge_attr = to_edge_index(edge_index)
     elif isinstance(num_nodes, (tuple, list)):
@@ -261,7 +268,11 @@ def add_self_loops(
 
     edge_index = torch.cat([edge_index, loop_index], dim=1)
     if is_sparse:
-        return to_torch_coo_tensor(edge_index, edge_attr, size=size), None
+        if str(layout) == 'torch.sparse_coo':  # str(...) for TorchScript :(
+            return to_torch_coo_tensor(edge_index, edge_attr, size), None
+        elif str(layout) == 'torch.sparse_csr':
+            return to_torch_csr_tensor(edge_index, edge_attr, size), None
+        raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')")
     return edge_index, edge_attr
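
Tracking the layout means both helpers can hand back the same sparse layout they received, instead of always producing COO; the str(layout) comparisons are a workaround because TorchScript cannot compare torch.layout values directly. A round-trip sketch on a toy CSR matrix:

    import torch
    from torch_geometric.utils import remove_self_loops, to_torch_csr_tensor

    edge_index = torch.tensor([[0, 1, 1],
                               [0, 1, 2]])  # (0, 0) and (1, 1) are self-loops
    adj = to_torch_csr_tensor(edge_index, size=3)

    adj, _ = remove_self_loops(adj)
    assert adj.layout == torch.sparse_csr  # CSR in, CSR out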
23 changes: 18 additions & 5 deletions torch_geometric/utils/sparse.py
@@ -85,6 +85,7 @@ def to_torch_coo_tensor(
     edge_index: Tensor,
     edge_attr: Optional[Tensor] = None,
     size: Optional[Union[int, Tuple[int, int]]] = None,
+    is_coalesced: bool = False,
 ) -> Tensor:
     r"""Converts a sparse adjacency matrix defined by edge indices and edge
     attributes to a :class:`torch.sparse.Tensor` with layout
@@ -99,6 +100,9 @@ def to_torch_coo_tensor(
             If given as an integer, will create a quadratic sparse matrix.
             If set to :obj:`None`, will infer a quadratic sparse matrix based
             on :obj:`edge_index.max() + 1`. (default: :obj:`None`)
+        is_coalesced (bool): If set to :obj:`True`, will assume that
+            :obj:`edge_index` is already coalesced and thus avoids expensive
+            computation. (default: :obj:`False`)
 
     :rtype: :class:`torch.sparse.Tensor`
@@ -123,18 +127,20 @@ def to_torch_coo_tensor(
 
     size = tuple(size) + edge_attr.size()[1:]
 
-    return torch.sparse_coo_tensor(
+    adj = torch.sparse_coo_tensor(
         indices=edge_index,
         values=edge_attr,
         size=size,
         device=edge_index.device,
-    ).coalesce()
+    )
+    return adj._coalesced_(True) if is_coalesced else adj.coalesce()
 
 
 def to_torch_csr_tensor(
     edge_index: Tensor,
     edge_attr: Optional[Tensor] = None,
     size: Optional[Union[int, Tuple[int, int]]] = None,
+    is_coalesced: bool = False,
 ) -> Tensor:
     r"""Converts a sparse adjacency matrix defined by edge indices and edge
     attributes to a :class:`torch.sparse.Tensor` with layout
@@ -149,6 +155,9 @@ def to_torch_csr_tensor(
             If given as an integer, will create a quadratic sparse matrix.
             If set to :obj:`None`, will infer a quadratic sparse matrix based
             on :obj:`edge_index.max() + 1`. (default: :obj:`None`)
+        is_coalesced (bool): If set to :obj:`True`, will assume that
+            :obj:`edge_index` is already coalesced and thus avoids expensive
+            computation. (default: :obj:`False`)
 
     :rtype: :class:`torch.sparse.Tensor`
@@ -163,14 +172,15 @@ def to_torch_csr_tensor(
            size=(4, 4), nnz=6, layout=torch.sparse_csr)
     """
-    adj = to_torch_coo_tensor(edge_index, edge_attr, size)
+    adj = to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
     return adj.to_sparse_csr()
 
 
 def to_torch_csc_tensor(
     edge_index: Tensor,
     edge_attr: Optional[Tensor] = None,
     size: Optional[Union[int, Tuple[int, int]]] = None,
+    is_coalesced: bool = False,
 ) -> Tensor:
     r"""Converts a sparse adjacency matrix defined by edge indices and edge
     attributes to a :class:`torch.sparse.Tensor` with layout
@@ -185,6 +195,9 @@ def to_torch_csc_tensor(
             If given as an integer, will create a quadratic sparse matrix.
             If set to :obj:`None`, will infer a quadratic sparse matrix based
             on :obj:`edge_index.max() + 1`. (default: :obj:`None`)
+        is_coalesced (bool): If set to :obj:`True`, will assume that
+            :obj:`edge_index` is already coalesced and thus avoids expensive
+            computation. (default: :obj:`False`)
 
     :rtype: :class:`torch.sparse.Tensor`
@@ -199,7 +212,7 @@ def to_torch_csc_tensor(
            size=(4, 4), nnz=6, layout=torch.sparse_csc)
     """
-    adj = to_torch_coo_tensor(edge_index, edge_attr, size)
+    adj = to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)
     return adj.to_sparse_csc()
 
 
@@ -241,7 +254,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
         row = adj.row_indices().detach()
         return torch.stack([row, col], dim=0).long(), adj.values()
 
-    raise ValueError(f"Expected sparse tensor layout (got '{adj.layout}')")
+    raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')")
 
 
 # Helper functions ############################################################
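
One caveat worth keeping in mind: is_coalesced=True only flags the tensor via the private _coalesced_(True) call, so passing it indices that still contain duplicates silently skips the merge. A sketch of the difference (toy input, not from the source):

    import torch
    from torch_geometric.utils import to_torch_coo_tensor

    # Edge (0, 1) appears twice on purpose:
    edge_index = torch.tensor([[0, 0, 1],
                               [1, 1, 2]])

    adj = to_torch_coo_tensor(edge_index)  # coalesce() merges the duplicates
    assert adj._nnz() == 2

    adj = to_torch_coo_tensor(edge_index, is_coalesced=True)  # flag only
    assert adj._nnz() == 3  # the duplicate edge survives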
