Skip to content

Adds GaLore style projection wrappers to optax optimizers

Notifications You must be signed in to change notification settings

EleutherAI/optax-galore

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

8 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Optax-GaLore

Optax-GaLore is an implementation of the Gradient Low-Rank Projection (GaLore) algorithm for memory-efficient training of Large Language Models (LLMs). This project extends the Optax optimization library with GaLore functionality.

Features

  • Memory-efficient optimization for large-scale models
  • Compatible with existing Optax optimizers
  • Flexible projection specifications for different layer types
  • Support for convolutional layers and other multi-dimensional tensors

Installation

To install Optax-GaLore, clone this repository and install the required dependencies:

git clone https://github.com/EleutherAI/optax-galore.git
cd optax-galore
pip install -r requirements.txt

Usage

Here's a basic example of how to use Optax-GaLore in your project:

import jax
import optax
import optax_galore.optax_galore as og

# Define your model and loss function
# ...

# Create a GaLore optimizer
learning_rate = 0.001
rank = 64
subspace_change_freq = 1000

optimizer = og.galore(
    learning_rate=learning_rate,
    rank=rank,
    subspace_change_freq=subspace_change_freq
)

# Initialize optimizer state
opt_state = optimizer.init(params)

# Define the loss function
def loss_fn(params, batch):
    # Your model's loss calculation
    return loss

# Define the update function
@jax.jit
def update(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

# In your training loop:
for batch in data_loader:
    params, opt_state, loss = update(params, opt_state, batch)

This updated usage example demonstrates how to create a jitted update function that includes the loss calculation, gradient computation, optimizer update, and parameter update. Using a jitted update function enables the compiler to optimize out the unprojected gradients to save memory (probably).

Advanced Usage

For more control over the projection dimensions, you can use the dimension_pytree parameter:

import optax_galore.optax_galore as og
from optax_galore.optax_galore import ProjectionSpec

dimension_pytree = {
    'conv1': {'w': ProjectionSpec(2, 3), 'b': None},
    'conv2': {'w': ProjectionSpec(2, 3), 'b': None}
}

optimizer = optax_galore.galore(
    learning_rate=learning_rate,
    rank=rank,
    subspace_change_freq=subspace_change_freq,
    dimension_pytree=dimension_pytree
)

You can also wrap other Optax optimizers with GaLore:

base_optimizer = optax.adam(learning_rate=0.001)
galore_optimizer = optax_galore.galore_wrapper(
    base_optimizer,
    rank=64,
    subspace_change_freq=1000
)

Testing

To run the tests, use pytest:

pytest tests/

Contributing

Contributions to Optax-GaLore are welcome! Please feel free to submit a Pull Request.

About

Adds GaLore style projection wrappers to optax optimizers

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages