
Heterogeneous Graph Transformer (HGT)

An alternative reference implementation is available in the Deep Graph Library (DGL).

Heterogeneous Graph Transformer is a graph neural network architecture designed to handle large-scale heterogeneous and dynamic graphs.

See our WWW 2020 paper, Heterogeneous Graph Transformer, for more details.

This implementation of HGT is based on the PyTorch Geometric API.

Overview

The most important files in this project are as follows:

  • conv.py: The core of our model, implements the transformer-like heterogeneous graph convolutional layer.
  • model.py: Wraps the different model components.
  • data.py: The data interface and its usage.
    • class Graph: The data structure for a heterogeneous graph. Stores node features in Graph.node_feature as pandas.DataFrame objects and the adjacency structure in Graph.edge_list as a dictionary.
    • def sample_subgraph: The sampling algorithm for heterogeneous graphs. Each iteration samples a fixed number of nodes per node type. All sampled nodes lie within the neighborhood of already-sampled nodes, and the sampling probability is proportional to the square of the relative degree.
  • train_*.py: The training and validation script for a specific downstream task.
    • def *_sample: The sampling function for a given task. Remember to mask out existing links within the graph to avoid information leakage.
    • def prepare_data: Conducts sampling in parallel across multiple processes so that it runs alongside model training.
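The sampling step described for sample_subgraph can be illustrated with the minimal, dependency-free sketch below. It is not the code in data.py: the adjacency format (adj[node_type][node_id] -> list of (neighbor_type, neighbor_id)), the function signature, and the parameter names sampled_depth and sampled_number are assumptions chosen for illustration. We read "relative degree" as a budget: each already-sampled node spreads 1/deg over its unsampled neighbors, and each iteration draws a fixed number of nodes per type with probability proportional to the squared accumulated budget.

```python
import random
from collections import defaultdict


def _weighted_sample(items, weights, k, rng):
    """Draw k distinct items, each draw proportional to the remaining weights."""
    items, weights = list(items), list(weights)
    chosen = []
    for _ in range(min(k, len(items))):
        idx = rng.choices(range(len(items)), weights=weights, k=1)[0]
        chosen.append(items.pop(idx))
        weights.pop(idx)
    return chosen


def sample_subgraph(adj, seeds, sampled_depth=2, sampled_number=2, seed=0):
    """Hypothetical budget-based heterogeneous sampler (not pyHGT's exact code).

    adj:   {node_type: {node_id: [(neighbor_type, neighbor_id), ...]}}
    seeds: {node_type: [node_id, ...]}, the target nodes of the batch
    Returns {node_type: set_of_node_ids}.
    """
    rng = random.Random(seed)
    sampled = defaultdict(set)
    budget = defaultdict(lambda: defaultdict(float))  # candidate -> accumulated weight

    def add_budget(node_type, node_id):
        neighbors = adj.get(node_type, {}).get(node_id, [])
        for nbr_type, nbr_id in neighbors:
            if nbr_id not in sampled[nbr_type]:
                # every unsampled neighbor receives 1/deg(source)
                budget[nbr_type][nbr_id] += 1.0 / len(neighbors)

    for node_type, ids in seeds.items():
        sampled[node_type].update(ids)
    for node_type, ids in seeds.items():
        for node_id in ids:
            add_budget(node_type, node_id)

    for _ in range(sampled_depth):
        new_nodes = []
        for node_type, candidates in list(budget.items()):
            if not candidates:
                continue
            keys = list(candidates)
            # probability proportional to the squared budget
            weights = [candidates[key] ** 2 for key in keys]
            for node_id in _weighted_sample(keys, weights, sampled_number, rng):
                new_nodes.append((node_type, node_id))
        # commit the whole layer only after every type has been sampled
        for node_type, node_id in new_nodes:
            sampled[node_type].add(node_id)
            budget[node_type].pop(node_id, None)
        for node_type, node_id in new_nodes:
            add_budget(node_type, node_id)

    return {t: ids for t, ids in sampled.items() if ids}


if __name__ == "__main__":
    toy = {
        "paper": {"p1": [("author", "a1"), ("author", "a2"), ("field", "f1")],
                  "p2": [("author", "a2"), ("field", "f1")]},
        "author": {"a1": [("paper", "p1")], "a2": [("paper", "p1"), ("paper", "p2")]},
        "field": {"f1": [("paper", "p1"), ("paper", "p2")]},
    }
    result = sample_subgraph(toy, {"paper": ["p1"]}, sampled_depth=2, sampled_number=1)
    print({t: sorted(ids) for t, ids in sorted(result.items())})
    # -> {'author': ['a1', 'a2'], 'field': ['f1'], 'paper': ['p1', 'p2']}
```

Because budget only flows to neighbors of sampled nodes, every sampled node stays connected to the seed set, in line with the description of sample_subgraph above; squaring the budget biases each draw toward nodes that many sampled nodes point to.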

Setup

This implementation is based on pytorch_geometric. The required dependencies are listed in requirements.txt; you can install all of them by running pip install -r requirements.txt.

OAG DataSet

Our current experiments are conducted on the Open Academic Graph (OAG). For ease of use, we split and preprocess the whole dataset at different granularities: all CS papers (8.1G), all ML papers (1.9G), and all NN papers (0.6G), spanning 1900-2020. You can download the preprocessed graph via this link.

If you want to process the raw data directly, you can download it via this link. After downloading, run preprocess_OAG.py to extract features and store them in our data structure.

You can also use our code to process other heterogeneous graphs, as long as you load them into our data structure class Graph in data.py. Refer to preprocess_OAG.py for a demonstration.
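To make the expected layout concrete, here is a minimal, hypothetical ToyGraph class that mimics the structure described for class Graph: per-type node features plus a nested edge_list dictionary. The nesting order (target type, source type, relation, target index, source index, storing a timestamp), the node_forward index map, the "rev_" naming for reverse edges, and all method names are assumptions for illustration; check data.py and preprocess_OAG.py for the actual interface. Features are kept as plain dicts to stay dependency-free, whereas Graph.node_feature holds pandas.DataFrame objects.

```python
from collections import defaultdict


def _edge_table():
    # edge_list[target_type][source_type][relation][target_idx][source_idx] = time
    return defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))))


class ToyGraph:
    """Hypothetical, simplified stand-in for data.Graph (layout is an assumption)."""

    def __init__(self):
        self.node_forward = defaultdict(dict)  # node_type -> {external_id: index}
        self.node_feature = defaultdict(list)  # node_type -> feature rows (a DataFrame in pyHGT)
        self.edge_list = _edge_table()

    def add_node(self, node_type, node_id, **features):
        """Register a node once and merge any features passed in; returns its index."""
        index = self.node_forward[node_type].get(node_id)
        if index is None:
            index = len(self.node_feature[node_type])
            self.node_forward[node_type][node_id] = index
            self.node_feature[node_type].append({"id": node_id})
        self.node_feature[node_type][index].update(features)
        return index

    def add_edge(self, source, target, relation, time=None):
        """Store the edge in both directions; the reverse relation gets a 'rev_' prefix."""
        (s_type, s_id), (t_type, t_id) = source, target
        s, t = self.add_node(s_type, s_id), self.add_node(t_type, t_id)
        self.edge_list[t_type][s_type][relation][t][s] = time
        self.edge_list[s_type][t_type]["rev_" + relation][s][t] = time

    def remove_edge(self, source, target, relation):
        """Drop both directions, e.g. to mask links a downstream task must predict."""
        (s_type, s_id), (t_type, t_id) = source, target
        s, t = self.node_forward[s_type][s_id], self.node_forward[t_type][t_id]
        self.edge_list[t_type][s_type][relation][t].pop(s, None)
        self.edge_list[s_type][t_type]["rev_" + relation][s].pop(t, None)


if __name__ == "__main__":
    g = ToyGraph()
    g.add_node("paper", "P1", year=2019)
    g.add_edge(("field", "F1"), ("paper", "P1"), "belongs_to", time=2019)
    print(g.edge_list["field"]["paper"]["rev_belongs_to"][0])  # -> {0: 2019}
```

Once a type's rows are collected, pandas.DataFrame(rows) would turn them into a DataFrame. remove_edge shows one way to apply the leakage advice given for the *_sample functions.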

Usage

Run the following script to train HGT on the paper-field (L2) classification task:

python3 train_paper_field.py --data_dir PATH_OF_DATASET --model_dir PATH_OF_SAVED_MODEL --conv_name hgt

The other two tasks are run in the same way. Key options of these scripts include:

  • --conv_name: The model to train. HGT is used by default.
  • --sample_depth and --sample_width: The depth and width of the sampled subgraph. If the model exceeds GPU memory, consider reducing these values; to train a deeper GNN model, consider increasing them.
  • --n_pool: The number of processes that conduct sampling in parallel. On a machine with large memory, consider increasing this number to reduce batch preparation time.
  • --repeat: The number of times each sampled batch is reused for training. If training time is much shorter than sampling time, consider increasing this number.
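The interplay between --n_pool and --repeat can be sketched as a producer/consumer loop: a pool samples the next round of batches while the model trains on the current round, and each batch is reused repeat times. This is a hedged sketch, not the code in train_*.py; sample_batch, train_epochs, and n_batch are hypothetical names. It uses a thread pool so it runs anywhere; for CPU-bound sampling in multiple processes, as prepare_data does, concurrent.futures.ProcessPoolExecutor (or multiprocessing.Pool) would be the natural substitute.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def sample_batch(batch_seed):
    """Hypothetical stand-in for a task-specific *_sample function."""
    rng = random.Random(batch_seed)
    return [rng.random() for _ in range(4)]


def train_epochs(n_epochs, n_batch, n_pool, repeat, train_step):
    """Overlap sampling with training; returns the number of training steps run."""
    steps = 0
    with ThreadPoolExecutor(max_workers=n_pool) as pool:
        # queue the first round of batches before training starts
        pending = [pool.submit(sample_batch, i) for i in range(n_batch)]
        for epoch in range(n_epochs):
            batches = [future.result() for future in pending]
            if epoch + 1 < n_epochs:
                # sample the next round in the background while this one trains
                offset = (epoch + 1) * n_batch
                pending = [pool.submit(sample_batch, offset + i) for i in range(n_batch)]
            # --repeat: reuse each sampled batch several times
            for _ in range(repeat):
                for batch in batches:
                    train_step(batch)
                    steps += 1
    return steps
```

For example, train_epochs(2, 3, 2, 2, print) runs 12 training steps on 6 distinct batches. Raising repeat amortizes sampling cost when training is fast, at the price of seeing fewer distinct subgraphs; raising n_pool shortens the wait on future.result() when enough memory is available.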

Details of the other optional hyperparameters can be found in train_*.py.

Citation

Please consider citing the following paper when using our code for your application.

@inproceedings{hgt,
  author    = {Ziniu Hu and
               Yuxiao Dong and
               Kuansan Wang and
               Yizhou Sun},
  title     = {Heterogeneous Graph Transformer},
  booktitle = {{WWW} '20: The Web Conference 2020, Taipei, Taiwan, April 20-24, 2020},
  pages     = {2704--2710},
  publisher = {{ACM} / {IW3C2}},
  year      = {2020},
  url       = {https://doi.org/10.1145/3366423.3380027},
  doi       = {10.1145/3366423.3380027},
  timestamp = {Wed, 06 May 2020 12:56:16 +0200},
  biburl    = {https://dblp.org/rec/conf/www/HuDWS20.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}