From d82e2241c77ebaaf6114327f0b3106197756dc72 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 17 Aug 2022 12:29:35 +0200 Subject: [PATCH] Add additional test to `RandomLinkSplit` for `HeteroData` (#5221) * add test * changelog * changelog --- CHANGELOG.md | 2 +- test/transforms/test_random_link_split.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39e04dd4c985..520931e2417e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -120,7 +120,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present - Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828)) - Do not allow accessing edge types in `HeteroData` with two node types when there exists multiple relations between these types ([#4782](https://github.com/pyg-team/pytorch_geometric/pull/4782)) -- Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757)) +- Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757), [#5221](https://github.com/pyg-team/pytorch_geometric/pull/5221)) - Fixed a numerical instability in the `GeneralConv` and `neighbor_sample` tests ([#4754](https://github.com/pyg-team/pytorch_geometric/pull/4754)) - Fixed a bug in `HANConv` in which destination node features rather than source node features were propagated ([#4753](https://github.com/pyg-team/pytorch_geometric/pull/4753)) - Fixed versions of `checkout` and `setup-python` in CI ([#4751](https://github.com/pyg-team/pytorch_geometric/pull/4751)) diff --git a/test/transforms/test_random_link_split.py b/test/transforms/test_random_link_split.py index 032abe081af3..b4b049ac3c88 100644 --- a/test/transforms/test_random_link_split.py +++ b/test/transforms/test_random_link_split.py @@ -191,3 +191,8 @@ def test_random_link_split_on_undirected_hetero_data(): rev_edge_types=('p', 'p')) train_data, val_data, test_data = transform(data) assert train_data['p', 'p'].is_undirected() + + transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'), + rev_edge_types=('p', 'p')) + train_data, val_data, test_data = transform(data) + assert train_data['p', 'p'].is_undirected()