
Functions to Run Kernels #51

Open
PaulScemama opened this issue Dec 7, 2023 · 1 comment
PaulScemama commented Dec 7, 2023

@junpenglao I'm taking a look at implementing the run_inference_loop from here. I'm running into a potential issue: some inference algorithms require more than rng_key and state as inputs to their step function. Take, for example, sgld, which requires a minibatch of data and a step size at each call to its .step.
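To make the mismatch concrete, here is a minimal sketch (toy NamedTuple algorithms with made-up update rules, not the actual BlackJAX API) contrasting the two step signatures:

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Algo(NamedTuple):
    init: Callable
    step: Callable


# A step that only needs (rng_key, state) fits the current run_inference_loop.
simple = Algo(
    init=lambda position: position,
    step=lambda rng_key, state: (state + jax.random.normal(rng_key), None),
)


# An SGLD-like step additionally needs a minibatch and a step size at every
# call, so the current loop has no way to supply those arguments.
def sgld_like_step(rng_key, state, minibatch, step_size):
    grad = jnp.mean(minibatch - state)  # toy gradient estimate from the minibatch
    noise = jax.random.normal(rng_key)
    return state + step_size * grad + jnp.sqrt(2 * step_size) * noise
```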

I suspect this will also be the case for the variational inference algorithms once they are in a more final state. run_inference_loop cannot currently handle these situations.

Should I leave those particular cases alone and use run_inference_loop wherever I can?

One potential solution, which allows batches to be passed in during step, is to modify run_inference_loop like so:

import jax
from jax import lax
from jax.random import split


def run_inference_algorithm(
    rng_key,
    initial_state_or_position,
    inference_algorithm,
    batches,
    num_steps,
) -> tuple[State, State, Info]:
    try:
        initial_state = inference_algorithm.init(initial_state_or_position)
    except TypeError:
        # We assume initial_state is already in the right format.
        initial_state = initial_state_or_position

    keys = split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        batch = next(batches)
        state, info = inference_algorithm.step(rng_key, state, batch)
        return state, (state, info)

    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
    return final_state, state_history, info_history

Here batches is any iterator (possibly a generator) over batches of data examples. However, if batches is a generator that uses any JAX operations, I have run into issues with scan (I'm not exactly sure why); if the generator uses, say, NumPy instead, then it does work.
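One way to sidestep the generator/tracing interaction entirely is to materialize a fixed number of batches into an array up front and let lax.scan carry them alongside the keys. A sketch with a toy step function (the step and shapes here are assumptions for illustration, not the real API):

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_with_stacked_batches(rng_key, initial_state, step_fn, stacked_batches):
    """Scan over (key, batch) pairs; stacked_batches has shape (num_steps, batch_size, ...)."""
    num_steps = stacked_batches.shape[0]
    keys = jax.random.split(rng_key, num_steps)

    def one_step(state, key_and_batch):
        key, batch = key_and_batch
        state, info = step_fn(key, state, batch)
        return state, (state, info)

    final_state, (states, infos) = lax.scan(
        one_step, initial_state, (keys, stacked_batches)
    )
    return final_state, states, infos


# Toy step: move the state toward the batch mean, returning the mean as "info".
def toy_step(key, state, batch):
    new_state = state + 0.1 * (jnp.mean(batch) - state)
    return new_state, jnp.mean(batch)
```

Since the batches are ordinary scanned-over arrays, no Python-side iteration happens during tracing.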

An example of a NumPy data generator:

def data_stream(seed, data, batch_size, data_size):
    """Return an iterator over batches of data."""
    rng = np.random.RandomState(seed)
    num_batches = int(np.ceil(data_size / batch_size))
    while True:
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]

batches = data_stream(...)
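For instance, with toy NumPy data (the concrete values are just for illustration; data_stream is repeated here so the snippet runs standalone):

```python
import numpy as np


def data_stream(seed, data, batch_size, data_size):
    """Return an iterator over batches of data."""
    rng = np.random.RandomState(seed)
    num_batches = int(np.ceil(data_size / batch_size))
    while True:
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]


data = np.arange(10)
batches = data_stream(seed=0, data=data, batch_size=4, data_size=len(data))
first = next(batches)   # shape (4,)
second = next(batches)  # shape (4,)
third = next(batches)   # shape (2,): the last, smaller batch of this epoch
```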

This also works with, say, a Hugging Face datasets data loader. Something like:

from datasets import Dataset
batches = Dataset.from_dict({"data":data}).with_format("jax").iter(batch_size=50)

I'm not sure this would be the preferred solution. In any case, I'll think about it some more.

Thanks!

junpenglao (Member) commented:

Let's leave out SGLD for now.
