From c78f358bb664a15650c68113f7108be89aca98da Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Thu, 23 Mar 2023 07:05:11 +0100 Subject: [PATCH] Fix full test (#7007) --- test/nn/conv/test_gen_conv.py | 32 ++++++++++++--------- test/nn/conv/test_graph_conv.py | 8 +++--- test/nn/dense/test_dense_gat_conv.py | 8 +++--- test/nn/dense/test_dense_gcn_conv.py | 8 +++--- test/nn/dense/test_linear.py | 2 -- test/utils/test_sparse.py | 9 +++++- torch_geometric/compile.py | 7 ++--- torch_geometric/nn/conv/cluster_gcn_conv.py | 12 ++------ torch_geometric/nn/conv/gatv2_conv.py | 1 + torch_geometric/nn/conv/gcn_conv.py | 13 +++------ torch_geometric/utils/loop.py | 23 +++++++++++---- torch_geometric/utils/sparse.py | 23 +++++++++++---- 12 files changed, 84 insertions(+), 62 deletions(-) diff --git a/test/nn/conv/test_gen_conv.py b/test/nn/conv/test_gen_conv.py index 21aa14e6b686..a66de738c79c 100644 --- a/test/nn/conv/test_gen_conv.py +++ b/test/nn/conv/test_gen_conv.py @@ -40,10 +40,12 @@ def test_gen_conv(aggr): if is_full_test(): t = '(Tensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit(x1, edge_index), out11) - assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11) - assert torch.allclose(jit(x1, edge_index, value), out12) - assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12) + assert torch.allclose(jit(x1, edge_index), out11, atol=1e-6) + assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out11, + atol=1e-6) + assert torch.allclose(jit(x1, edge_index, value), out12, atol=1e-6) + assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out12, + atol=1e-6) t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) @@ -71,10 +73,13 @@ def test_gen_conv(aggr): if is_full_test(): t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, x2), edge_index), out21) - assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21) - assert torch.allclose(jit((x1, x2), edge_index, value), out22) - assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22) + assert torch.allclose(jit((x1, x2), edge_index), out21, atol=1e-6) + assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out21, + atol=1e-6) + assert torch.allclose(jit((x1, x2), edge_index, value), out22, + atol=1e-6) + assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out22, + atol=1e-6) t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) @@ -120,13 +125,14 @@ def test_gen_conv(aggr): if is_full_test(): t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, x2), edge_index, value), out1) + assert torch.allclose(jit((x1, x2), edge_index, value), out1, + atol=1e-6) assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)), - out1) + out1, atol=1e-6) assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)), - out2) + out2, atol=1e-6) t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, x2), adj1.t()), out1) - assert torch.allclose(jit((x1, None), adj1.t()), out2) + assert torch.allclose(jit((x1, x2), adj1.t()), out1, atol=1e-6) + assert torch.allclose(jit((x1, None), adj1.t()), out2, atol=1e-6) diff --git a/test/nn/conv/test_graph_conv.py b/test/nn/conv/test_graph_conv.py index dd0de6439ce9..c80bb6d4de48 100644 --- a/test/nn/conv/test_graph_conv.py +++ b/test/nn/conv/test_graph_conv.py @@ -80,7 +80,7 @@ def test_graph_conv(): t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor' jit = torch.jit.script(conv.jittable(t)) - assert torch.allclose(jit((x1, x2), adj1.t()), out21) - assert torch.allclose(jit((x1, x2), adj2.t()), out22) - assert torch.allclose(jit((x1, None), adj1.t()), out23) - assert torch.allclose(jit((x1, None), adj2.t()), out24) + assert torch.allclose(jit((x1, x2), adj1.t()), out21, atol=1e-6) + assert torch.allclose(jit((x1, x2), adj2.t()), out22, atol=1e-6) + assert torch.allclose(jit((x1, None), adj1.t()), out23, atol=1e-6) + assert torch.allclose(jit((x1, None), adj2.t()), out24, atol=1e-6) diff --git a/test/nn/dense/test_dense_gat_conv.py b/test/nn/dense/test_dense_gat_conv.py index 72c4ccc4b153..94a5f7fd5068 100644 --- a/test/nn/dense/test_dense_gat_conv.py +++ b/test/nn/dense/test_dense_gat_conv.py @@ -41,14 +41,14 @@ def test_dense_gat_conv(heads, concat): dense_out = dense_conv(x, adj, mask) - assert dense_out[1, 2].abs().sum() == 0 - dense_out = dense_out.view(6, dense_out.size(-1))[:-1] - assert torch.allclose(sparse_out, dense_out, atol=1e-4) - if is_full_test(): jit = torch.jit.script(dense_conv) assert torch.allclose(jit(x, adj, mask), dense_out) + assert dense_out[1, 2].abs().sum() == 0 + dense_out = dense_out.view(6, dense_out.size(-1))[:-1] + assert torch.allclose(sparse_out, dense_out, atol=1e-4) + def test_dense_gat_conv_with_broadcasting(): batch_size, num_nodes, channels = 8, 3, 16 diff --git a/test/nn/dense/test_dense_gcn_conv.py b/test/nn/dense/test_dense_gcn_conv.py index 7d43ec0fcb6d..20237b80addf 100644 --- a/test/nn/dense/test_dense_gcn_conv.py +++ b/test/nn/dense/test_dense_gcn_conv.py @@ -39,14 +39,14 @@ def test_dense_gcn_conv(): dense_out = dense_conv(x, adj, mask) assert dense_out.size() == (2, 3, channels) - assert dense_out[1, 2].abs().sum() == 0 - dense_out = dense_out.view(6, channels)[:-1] - assert torch.allclose(sparse_out, dense_out, atol=1e-4) - if is_full_test(): jit = torch.jit.script(dense_conv) assert torch.allclose(jit(x, adj, mask), dense_out) + assert dense_out[1, 2].abs().sum() == 0 + dense_out = dense_out.view(6, channels)[:-1] + assert torch.allclose(sparse_out, dense_out, atol=1e-4) + def test_dense_gcn_conv_with_broadcasting(): batch_size, num_nodes, channels = 8, 3, 16 diff --git a/test/nn/dense/test_linear.py b/test/nn/dense/test_linear.py index d5e9185f152e..e3b04cd46590 100644 --- a/test/nn/dense/test_linear.py +++ b/test/nn/dense/test_linear.py @@ -125,7 +125,6 @@ def test_lazy_hetero_linear(): out = lin(x, type_vec) assert out.size() == (3, 32) - assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)' def test_hetero_dict_linear(): @@ -160,7 +159,6 @@ def test_lazy_hetero_dict_linear(): assert len(out_dict) == 2 assert out_dict['v'].size() == (3, 32) assert out_dict['w'].size() == (2, 32) - assert str(lin) == "HeteroDictLinear({'v': 16, 'w': 8}, 32, bias=True)" @withPackage('pyg_lib') diff --git a/test/utils/test_sparse.py b/test/utils/test_sparse.py index fd11dcdcf68e..7814846a6b9d 100644 --- a/test/utils/test_sparse.py +++ b/test/utils/test_sparse.py @@ -76,7 +76,14 @@ def test_to_torch_coo_tensor(): ]) edge_attr = torch.randn(edge_index.size(1), 8) - adj = to_torch_coo_tensor(edge_index) + adj = to_torch_coo_tensor(edge_index, is_coalesced=False) + assert adj.is_coalesced() + assert adj.size() == (4, 4) + assert adj.layout == torch.sparse_coo + assert torch.allclose(adj.indices(), edge_index) + + adj = to_torch_coo_tensor(edge_index, is_coalesced=True) + assert adj.is_coalesced() assert adj.size() == (4, 4) assert adj.layout == torch.sparse_coo assert torch.allclose(adj.indices(), edge_index) diff --git a/torch_geometric/compile.py b/torch_geometric/compile.py index e38f389dec61..6fcfc1966a79 100644 --- a/torch_geometric/compile.py +++ b/torch_geometric/compile.py @@ -76,7 +76,8 @@ def fn(model: Callable) -> Callable: for key in prev_state.keys(): setattr(torch_geometric.typing, key, False) - # Temporarily adjust the logging level of `torch.compile`: + # Adjust the logging level of `torch.compile`: + # TODO (matthias) Disable only temporarily prev_log_level = { 'torch._dynamo': logging.getLogger('torch._dynamo').level, 'torch._inductor': logging.getLogger('torch._inductor').level, @@ -91,8 +92,4 @@ def fn(model: Callable) -> Callable: # Finally, run `torch.compile` to create an optimized version: out = torch.compile(model, *args, **kwargs) - # Restore the previous state: - for key, value in prev_log_level.items(): - logging.getLogger(key).setLevel(value) - return out diff --git a/torch_geometric/nn/conv/cluster_gcn_conv.py b/torch_geometric/nn/conv/cluster_gcn_conv.py index 3a1571f70400..ed715c829430 100644 --- a/torch_geometric/nn/conv/cluster_gcn_conv.py +++ b/torch_geometric/nn/conv/cluster_gcn_conv.py @@ -12,7 +12,7 @@ spmm, to_edge_index, ) -from torch_geometric.utils.sparse import get_sparse_diag, set_sparse_value +from torch_geometric.utils.sparse import set_sparse_value class ClusterGCNConv(MessagePassing): @@ -71,6 +71,7 @@ def reset_parameters(self): self.lin_root.reset_parameters() def forward(self, x: Tensor, edge_index: Adj) -> Tensor: + num_nodes = x.size(self.node_dim) edge_weight: OptTensor = None if isinstance(edge_index, SparseTensor): @@ -94,13 +95,7 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor: "supported in 'gcn_norm'") if self.add_self_loops: - diag = get_sparse_diag(edge_index.size(0), 1.0, - edge_index.layout, edge_index.dtype, - edge_index.device) - edge_index = edge_index + diag - - if edge_index.layout == torch.sparse_coo: - edge_index = edge_index.coalesce() + edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) col_and_row, value = to_edge_index(edge_index) col, row = col_and_row[0], col_and_row[1] @@ -112,7 +107,6 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor: edge_index = set_sparse_value(edge_index, edge_weight) else: - num_nodes = x.size(self.node_dim) if self.add_self_loops: edge_index, _ = remove_self_loops(edge_index) edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index bb72856b55ec..2de2d95685b6 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -251,6 +251,7 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, size=None) alpha = self._alpha + assert alpha is not None self._alpha = None if self.concat: diff --git a/torch_geometric/nn/conv/gcn_conv.py b/torch_geometric/nn/conv/gcn_conv.py index 0b4d6a54ccef..cc61dfc0b05f 100644 --- a/torch_geometric/nn/conv/gcn_conv.py +++ b/torch_geometric/nn/conv/gcn_conv.py @@ -14,15 +14,16 @@ SparseTensor, torch_sparse, ) +from torch_geometric.utils import add_remaining_self_loops +from torch_geometric.utils import add_self_loops as add_self_loops_fn from torch_geometric.utils import ( - add_remaining_self_loops, is_torch_sparse_tensor, scatter, spmm, to_edge_index, ) from torch_geometric.utils.num_nodes import maybe_num_nodes -from torch_geometric.utils.sparse import get_sparse_diag, set_sparse_value +from torch_geometric.utils.sparse import set_sparse_value @torch.jit._overload @@ -70,14 +71,8 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, "supported in 'gcn_norm'") adj_t = edge_index - if add_self_loops: - diag = get_sparse_diag(adj_t.size(0), fill_value, adj_t.layout, - adj_t.dtype, adj_t.device) - adj_t = adj_t + diag - - if adj_t.layout == torch.sparse_coo: - adj_t = adj_t.coalesce() + adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes) edge_index, value = to_edge_index(adj_t) col, row = edge_index[0], edge_index[1] diff --git a/torch_geometric/utils/loop.py b/torch_geometric/utils/loop.py index d21283d76876..bb847e2fa9e2 100644 --- a/torch_geometric/utils/loop.py +++ b/torch_geometric/utils/loop.py @@ -10,6 +10,7 @@ is_torch_sparse_tensor, to_edge_index, to_torch_coo_tensor, + to_torch_csr_tensor, ) @@ -65,21 +66,25 @@ def remove_self_loops( [3, 4]])) """ size: Optional[Tuple[int, int]] = None + layout: Optional[int] = None - is_sparse = is_torch_sparse_tensor(edge_index) - if is_sparse: + if is_torch_sparse_tensor(edge_index): assert edge_attr is None + layout = edge_index.layout size = (edge_index.size(0), edge_index.size(1)) edge_index, edge_attr = to_edge_index(edge_index) mask = edge_index[0] != edge_index[1] edge_index = edge_index[:, mask] - if is_sparse: + if layout is not None: assert edge_attr is not None edge_attr = edge_attr[mask] - adj = to_torch_coo_tensor(edge_index, edge_attr, size=size) - return adj, None + if str(layout) == 'torch.sparse_coo': # str(...) for TorchScript :( + return to_torch_coo_tensor(edge_index, edge_attr, size, True), None + elif str(layout) == 'torch.sparse_csr': + return to_torch_csr_tensor(edge_index, edge_attr, size, True), None + raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')") if edge_attr is None: return edge_index, None @@ -220,10 +225,12 @@ def add_self_loops( [1, 0, 0, 0, 1]]), tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000])) """ + layout: Optional[int] = None is_sparse = is_torch_sparse_tensor(edge_index) if is_sparse: assert edge_attr is None + layout = edge_index.layout size = (edge_index.size(0), edge_index.size(1)) edge_index, edge_attr = to_edge_index(edge_index) elif isinstance(num_nodes, (tuple, list)): @@ -261,7 +268,11 @@ def add_self_loops( edge_index = torch.cat([edge_index, loop_index], dim=1) if is_sparse: - return to_torch_coo_tensor(edge_index, edge_attr, size=size), None + if str(layout) == 'torch.sparse_coo': # str(...) for TorchScript :( + return to_torch_coo_tensor(edge_index, edge_attr, size), None + elif str(layout) == 'torch.sparse_csr': + return to_torch_csr_tensor(edge_index, edge_attr, size), None + raise ValueError(f"Unexpected sparse tensor layout (got '{layout}')") return edge_index, edge_attr diff --git a/torch_geometric/utils/sparse.py b/torch_geometric/utils/sparse.py index 8de55e8a4534..e2c6034bb9c9 100644 --- a/torch_geometric/utils/sparse.py +++ b/torch_geometric/utils/sparse.py @@ -85,6 +85,7 @@ def to_torch_coo_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, + is_coalesced: bool = False, ) -> Tensor: r"""Converts a sparse adjacency matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with layout @@ -99,6 +100,9 @@ def to_torch_coo_tensor( If given as an integer, will create a quadratic sparse matrix. If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) + is_coalesced (bool): If set to :obj:`True`, will assume that + :obj:`edge_index` is already coalesced and thus avoids expensive + computation. (default: :obj:`False`) :rtype: :class:`torch.sparse.Tensor` @@ -123,18 +127,20 @@ def to_torch_coo_tensor( size = tuple(size) + edge_attr.size()[1:] - return torch.sparse_coo_tensor( + adj = torch.sparse_coo_tensor( indices=edge_index, values=edge_attr, size=size, device=edge_index.device, - ).coalesce() + ) + return adj._coalesced_(True) if is_coalesced else adj.coalesce() def to_torch_csr_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, + is_coalesced: bool = False, ) -> Tensor: r"""Converts a sparse adjacency matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with layout @@ -149,6 +155,9 @@ def to_torch_csr_tensor( If given as an integer, will create a quadratic sparse matrix. If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) + is_coalesced (bool): If set to :obj:`True`, will assume that + :obj:`edge_index` is already coalesced and thus avoids expensive + computation. (default: :obj:`False`) :rtype: :class:`torch.sparse.Tensor` @@ -163,7 +172,7 @@ def to_torch_csr_tensor( size=(4, 4), nnz=6, layout=torch.sparse_csr) """ - adj = to_torch_coo_tensor(edge_index, edge_attr, size) + adj = to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced) return adj.to_sparse_csr() @@ -171,6 +180,7 @@ def to_torch_csc_tensor( edge_index: Tensor, edge_attr: Optional[Tensor] = None, size: Optional[Union[int, Tuple[int, int]]] = None, + is_coalesced: bool = False, ) -> Tensor: r"""Converts a sparse adjacency matrix defined by edge indices and edge attributes to a :class:`torch.sparse.Tensor` with layout @@ -185,6 +195,9 @@ def to_torch_csc_tensor( If given as an integer, will create a quadratic sparse matrix. If set to :obj:`None`, will infer a quadratic sparse matrix based on :obj:`edge_index.max() + 1`. (default: :obj:`None`) + is_coalesced (bool): If set to :obj:`True`, will assume that + :obj:`edge_index` is already coalesced and thus avoids expensive + computation. (default: :obj:`False`) :rtype: :class:`torch.sparse.Tensor` @@ -199,7 +212,7 @@ def to_torch_csc_tensor( size=(4, 4), nnz=6, layout=torch.sparse_csc) """ - adj = to_torch_coo_tensor(edge_index, edge_attr, size) + adj = to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced) return adj.to_sparse_csc() @@ -241,7 +254,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]: row = adj.row_indices().detach() return torch.stack([row, col], dim=0).long(), adj.values() - raise ValueError(f"Expected sparse tensor layout (got '{adj.layout}')") + raise ValueError(f"Unexpected sparse tensor layout (got '{adj.layout}')") # Helper functions ############################################################