
LongNet: Scaling Transformers to 1,000,000,000 Tokens


This is an open source implementation of the paper LongNet: Scaling Transformers to 1,000,000,000 Tokens by Jiayu Ding, Shuming Ma, Li Dong, Xingxing Zhang, Shaohan Huang, Wenhui Wang, and Furu Wei. LongNet is a Transformer variant designed to scale sequence length to more than 1 billion tokens without sacrificing performance on shorter sequences.

News 📰

Installation

You can install LongNet using one of the following methods:

Method 1: Git Clone

  1. Clone the LongNet repository from GitHub:
git clone https://github.com/kyegomez/LongNet.git
  2. Navigate to the cloned directory:
cd LongNet
  3. Install the required dependencies:
pip install -r requirements.txt

Method 2: Pip Install

  • Note that a plain pip install does not work, because the flash-attn dependency ships custom CUDA kernels that cannot be compiled automatically and must be built manually.
  1. Install LongNet directly from PyPI using pip:
pip install LongNet

Please note that LongNet requires a compatible Python version (tested with Python 3.7).
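To verify the installation, a quick import check like the one below should succeed. This is a minimal sanity sketch; it assumes the module path used in the usage example that follows and only confirms that PyTorch and the DilatedAttention class import cleanly.

# Minimal import check (assumes the LongNet.attention module path from the usage example below)
import torch
from LongNet.attention import DilatedAttention

print(torch.__version__, "CUDA available:", torch.cuda.is_available())
print(DilatedAttention)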

Usage

Once you have installed LongNet, you can use the DilatedAttention class as follows:

import timeit
import torch
from LongNet.attention import DilatedAttention


# model config
d_model = 512
num_heads = 8
dilation_rate = 2
segment_size = 64

device = "cuda:0"
dtype = torch.float16

# input data
batch_size = 32
seq_len = 10000000


# create model and data (cast the model to the same dtype as the input)
model = DilatedAttention(d_model, num_heads, dilation_rate, segment_size).to(device, dtype)
x = torch.randn((batch_size, seq_len, d_model), device=device, dtype=dtype)


# test forward pass
with torch.no_grad():
    output = model(x)
    print(f"Output shape: {output.shape}")  # expected (batch_size, seq_len, d_model)


# benchmark model
num_runs = 1000
start_time = timeit.default_timer()
with torch.no_grad():
    for _ in range(num_runs):
        model(x)
torch.cuda.synchronize()  # wait for queued CUDA kernels before reading the timer

elapsed_time = timeit.default_timer() - start_time
print(f"Average forward pass time: {elapsed_time / num_runs:.6f} seconds")

Introduction

Scaling sequence length has become a critical bottleneck in the era of large language models. However, existing methods struggle with either computational complexity or model expressivity, which restricts the maximum sequence length. The paper introduces LongNet, a Transformer variant that can scale sequence length to more than 1 billion tokens without sacrificing performance on shorter sequences. Specifically, the authors propose dilated attention, which expands the attentive field exponentially as the distance grows.
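To make the sparsification concrete, the toy sketch below uses plain PyTorch indexing to show how a sequence can be split into segments of length w and then subsampled with a dilation rate r, so that attention inside each segment only sees every r-th token. This illustrates the idea only; it is not the repository's implementation.

# Toy illustration of dilated attention's sparsification (not the repository's implementation)
import torch

seq_len, w, r = 16, 8, 2      # toy sizes: segment length w, dilation rate r
x = torch.arange(seq_len)     # stand-in for token positions

segments = x.view(-1, w)      # (num_segments, w)
dilated = segments[:, ::r]    # each segment keeps w // r positions
print(dilated)
# tensor([[ 0,  2,  4,  6],
#         [ 8, 10, 12, 14]])
# Attention is computed inside each dilated segment, so the per-segment cost drops
# from O(w^2) to O((w/r)^2); mixing several (w, r) pairs lets the attentive field
# grow with distance while the total cost stays roughly linear in seq_len.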

Features

LongNet has significant advantages:

  1. It has linear computational complexity and a logarithmic dependency between any two tokens in a sequence.
  2. It can serve as a distributed trainer for extremely long sequences.
  3. Its dilated attention is a drop-in replacement for standard attention and can be seamlessly integrated with existing Transformer-based optimizations; a minimal sketch of such a drop-in block follows this section.

Experimental results demonstrate that LongNet yields strong performance on both long-sequence modeling and general language tasks. Their work opens up new possibilities for modeling very long sequences, e.g., treating a whole corpus or even the entire Internet as a sequence.
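Because dilated attention keeps the standard (batch_size, seq_len, d_model) interface, it can be dropped into an ordinary Transformer block. The sketch below is an assumed minimal pre-norm block built around the DilatedAttention class from the usage example; the layer norm placement and feed-forward sizes are illustrative choices, not the exact configuration from the paper.

# Minimal pre-norm Transformer block using dilated attention (illustrative sketch)
import torch
import torch.nn as nn
from LongNet.attention import DilatedAttention

class DilatedTransformerBlock(nn.Module):
    def __init__(self, d_model=512, num_heads=8, dilation_rate=2, segment_size=64):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = DilatedAttention(d_model, num_heads, dilation_rate, segment_size)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # dilated attention in place of standard attention
        x = x + self.ff(self.norm2(x))    # position-wise feed-forward
        return x

A stack of such blocks plus an embedding layer and output head would give a LongNet-style decoder; prefer the repository's own model classes where they exist.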


Documentation

Training the Model

  • We're still working to match the model configuration from the paper as closely as possible. There are two training methods: one uses accelerate, the other uses from LongNet import Train. A minimal sketch of an accelerate-style training script is shown after the two methods below.

Method 1

  • Git clone installation

  • Initialize your parameters with accelerate config

  • Then launch training with accelerate launch LongNet/training.py

Method 2

  • Pip install method
from LongNet import Train

Train()
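The sketch below is a rough outline of what an accelerate-driven training script generally looks like; LongNet/training.py in the repository is the authoritative version, and the model, data, and loss here are placeholders that assume the DilatedAttention constructor signature from the usage example.

# Placeholder training loop with Hugging Face accelerate; see LongNet/training.py for the real script.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from LongNet.attention import DilatedAttention

accelerator = Accelerator()

model = DilatedAttention(512, 8, 2, 64)                     # stands in for the full LongNet model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loader = DataLoader(TensorDataset(torch.randn(16, 1024, 512)), batch_size=2)

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for (batch,) in loader:
    output = model(batch)
    loss = output.float().pow(2).mean()                     # placeholder loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()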

Share with Friends

Share LongNet with your friends and colleagues who might find it useful. Thank you for sharing!

Roadmap

  • Recreate the sparsification mechanism

  • Recreate the gathering mechanism

  • Implement FlashAttention 2.0

  • Implement Distributed Setup

  • Create the all-gather operation in the backward pass so that it becomes a reduce-scatter operation

Citation

@article{ding2023longnet,
  title={LongNet: Scaling Transformers to 1,000,000,000 Tokens},
  author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Wei, Furu},
  journal={arXiv preprint arXiv:2307.02486},
  year={2023}
}
