From d47d9cde477e1ab25821178968538eddeabc351a Mon Sep 17 00:00:00 2001 From: Stefano Roncelli <45285915+saiden89@users.noreply.github.com> Date: Wed, 22 Dec 2021 07:43:29 +0100 Subject: [PATCH] Remove size parameter for GATv2 and HEAT (#3744) * refactor heat_conv and test * refactor gatv2_conv and test --- test/nn/conv/test_gatv2_conv.py | 18 ++++++++---------- test/nn/conv/test_heat_conv.py | 4 ++-- torch_geometric/nn/conv/gatv2_conv.py | 16 +++++++--------- torch_geometric/nn/conv/heat_conv.py | 7 +++---- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index a6f6c4360919..ee53ff16a185 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -14,15 +14,14 @@ def test_gatv2_conv(): assert conv.__repr__() == 'GATv2Conv(8, 32, heads=2)' out = conv(x1, edge_index) assert out.size() == (4, 64) - assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) + assert torch.allclose(conv(x1, edge_index), out) assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6) - t = '(Tensor, Tensor, OptTensor, Size, NoneType) -> Tensor' + t = '(Tensor, Tensor, OptTensor, NoneType) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit(x1, edge_index), out) - assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) - t = '(Tensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor' + t = '(Tensor, SparseTensor, OptTensor, NoneType) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6) @@ -39,7 +38,7 @@ def test_gatv2_conv(): assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 assert conv._alpha is None - t = ('(Tensor, Tensor, OptTensor, Size, bool) -> ' + t = ('(Tensor, Tensor, OptTensor, bool) -> ' 'Tuple[Tensor, Tuple[Tensor, Tensor]]') jit = torch.jit.script(conv.jittable(t)) result = jit(x1, edge_index, return_attention_weights=True) @@ -49,7 +48,7 @@ def test_gatv2_conv(): assert result[1][1].min() >= 0 and result[1][1].max() <= 1 assert conv._alpha is None - t = ('(Tensor, SparseTensor, OptTensor, Size, bool) -> ' + t = ('(Tensor, SparseTensor, OptTensor, bool) -> ' 'Tuple[Tensor, SparseTensor]') jit = torch.jit.script(conv.jittable(t)) result = jit(x1, adj.t(), return_attention_weights=True) @@ -60,15 +59,14 @@ def test_gatv2_conv(): adj = adj.sparse_resize((4, 2)) out1 = conv((x1, x2), edge_index) assert out1.size() == (2, 64) - assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1) + assert torch.allclose(conv((x1, x2), edge_index), out1) assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6) - t = '(OptPairTensor, Tensor, OptTensor, Size, NoneType) -> Tensor' + t = '(OptPairTensor, Tensor, OptTensor, NoneType) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit((x1, x2), edge_index), out1) - assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1) - t = '(OptPairTensor, SparseTensor, OptTensor, Size, NoneType) -> Tensor' + t = '(OptPairTensor, SparseTensor, OptTensor, NoneType) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6) diff --git a/test/nn/conv/test_heat_conv.py b/test/nn/conv/test_heat_conv.py index 8322c1673a50..82f90c770fb2 100644 --- a/test/nn/conv/test_heat_conv.py +++ b/test/nn/conv/test_heat_conv.py @@ -20,7 +20,7 @@ def test_heat_conv(): assert out.size() == (4, 32) assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out) - t = '(Tensor, Tensor, Tensor, Tensor, OptTensor, Size) -> Tensor' + t = '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit(x, edge_index, node_type, edge_type, edge_attr), out) @@ -33,6 +33,6 @@ def test_heat_conv(): assert out.size() == (4, 16) assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out) - t = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor, Size) -> Tensor' + t = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit(x, adj.t(), node_type, edge_type), out) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index 707093282c09..7a38b7773c41 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -1,5 +1,5 @@ from typing import Union, Tuple, Optional -from torch_geometric.typing import (Adj, Size, OptTensor, PairTensor) +from torch_geometric.typing import (Adj, OptTensor, PairTensor) import torch from torch import Tensor @@ -163,12 +163,12 @@ def reset_parameters(self): zeros(self.bias) def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, - edge_attr: OptTensor = None, size: Size = None, + edge_attr: OptTensor = None, return_attention_weights: bool = None): - # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, PairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa - # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa + # type: (Union[Tensor, PairTensor], Tensor, OptTensor, NoneType) -> Tensor # noqa + # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, NoneType) -> Tensor # noqa + # type: (Union[Tensor, PairTensor], Tensor, OptTensor, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa + # type: (Union[Tensor, PairTensor], SparseTensor, OptTensor, bool) -> Tuple[Tensor, SparseTensor] # noqa r""" Args: return_attention_weights (bool, optional): If set to :obj:`True`, @@ -202,8 +202,6 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, num_nodes = x_l.size(0) if x_r is not None: num_nodes = min(num_nodes, x_r.size(0)) - if size is not None: - num_nodes = min(size[0], size[1]) edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) edge_index, edge_attr = add_self_loops( @@ -220,7 +218,7 @@ def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, # propagate_type: (x: PairTensor, edge_attr: OptTensor) out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, - size=size) + size=None) alpha = self._alpha self._alpha = None diff --git a/torch_geometric/nn/conv/heat_conv.py b/torch_geometric/nn/conv/heat_conv.py index c14ec314efdf..0faa4d1e7916 100644 --- a/torch_geometric/nn/conv/heat_conv.py +++ b/torch_geometric/nn/conv/heat_conv.py @@ -1,5 +1,5 @@ from typing import Optional -from torch_geometric.typing import Adj, Size, OptTensor +from torch_geometric.typing import Adj, OptTensor import torch from torch import Tensor @@ -89,8 +89,7 @@ def reset_parameters(self): self.lin.reset_parameters() def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor, - edge_type: Tensor, edge_attr: OptTensor = None, - size: Size = None) -> Tensor: + edge_type: Tensor, edge_attr: OptTensor = None) -> Tensor: """""" x = self.hetero_lin(x, node_type) @@ -99,7 +98,7 @@ def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor, # propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor) # noqa out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb, - edge_attr=edge_attr, size=size) + edge_attr=edge_attr, size=None) if self.concat: if self.root_weight: