
Implement model transform to remove minibatching operations from graph #7521

Open · aphc14 opened this issue Oct 2, 2024 · 4 comments

aphc14 commented Oct 2, 2024

Description

When using pm.Minibatch, pm.sample_posterior_predictive returns predictions with the size of the minibatch rather than the full dataset. Making predictions on the full dataset requires passing the existing trace into a new model with an almost identical setup. For complicated models, this means duplicating several lines of code to rebuild a model that differs from the original only in its data.

This enhancement would make it easier to perform posterior predictive checks when using minibatches.

relates to: https://discourse.pymc.io/t/minibatch-not-working/14061/10

Example scenario:

import numpy as np
import pymc as pm
import arviz as az
import pytensor.tensor as pt

# generate data
N = 10000
P = 3
rng = np.random.default_rng(88)
X_full = rng.uniform(2, 10, size=(N, P))
beta = np.array([1.5, 0.2, -0.9])
y_full = np.matmul(X_full, beta) + rng.normal(0, 1, size=(N,))

Before:

# minibatch
X_mb, y_mb = pm.Minibatch(X_full, y_full, batch_size=100)

# original minibatch model
with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pm.Deterministic("mu", pt.matmul(X_mb, b))
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_mb, total_size=N
    )

    fit_mb = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    idata_mb = fit_mb.sample(500)

    pm.sample_posterior_predictive(idata_mb, extend_inferencedata=True)
    idata_mb.posterior = pm.compute_deterministics(
        idata_mb.posterior, merge_dataset=True
    )

# new but similar model to the original
with pm.Model() as model_preds:
    X = pm.Data("X", X_full)
    y = pm.Data("y", y_full)

    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pm.Deterministic("mu", pt.matmul(X, b))
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y)
    
with model_preds:    
    pm.set_data({"X": X_full})
    ypreds = pm.sample_posterior_predictive(idata_mb)

print(f"Minibatch: {idata_mb.posterior_predictive.likelihood.sizes}")
print(f"Full Data: {ypreds.posterior_predictive.likelihood.sizes}")

# output
Minibatch: Frozen({'chain': 1, 'draw': 500, 'likelihood_dim_2': 100})
Full Data: Frozen({'chain': 1, 'draw': 500, 'likelihood_dim_2': 10000})

aphc14 added the bug label on Oct 2, 2024
welcome bot commented Oct 2, 2024

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

ricardoV94 (Member) commented Oct 2, 2024

This is the correct behavior. Minibatch is defined as a stochastic slice of the random variable. You can define a model without minibatch and use the old trace to sample a full-size dataset with posterior predictive.
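
For example, a minimal sketch of that workaround, reusing N, P, X_full, and y_full from the example above (the build_model helper is just an illustrative user function, not part of PyMC):

def build_model(X, y, total_size=None):
    # Single model definition reused for both fitting and prediction
    with pm.Model() as model:
        b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
        sigma = pm.HalfCauchy("sigma", 1)
        mu = pm.Deterministic("mu", pt.matmul(X, b))
        pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y, total_size=total_size)
    return model

# fit on minibatches
X_mb, y_mb = pm.Minibatch(X_full, y_full, batch_size=100)
with build_model(X_mb, y_mb, total_size=N):
    idata = pm.fit(n=100000, method="advi", random_seed=88).sample(500)

# posterior predictive on the full dataset, reusing the same trace
with build_model(X_full, y_full):
    pm.sample_posterior_predictive(idata, extend_inferencedata=True)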

aphc14 changed the title from "BUG: Minibatch posterior predictive sampling returns incorrect data size" to "ENH: Minibatch posterior predictive sampling to return predictions of original data" on Oct 3, 2024
aphc14 (Author) commented Oct 4, 2024

I have edited the original post to describe the issue as an enhancement/feature request rather than a bug. I'm not able to modify labels on my end though.

ricardoV94 changed the title from "ENH: Minibatch posterior predictive sampling to return predictions of original data" to "Implement model transform to remove minibatching operations from graph" on Oct 4, 2024
ricardoV94 (Member) commented

> I have edited the original post to describe the issue as an enhancement/feature request rather than a bug. I'm not able to modify labels on my end though.

Thanks, I did that. Note that my suggestion of a model transform wouldn't do anything automatically. The API would be something like:

with pm.Model() as minibatch_m:
  ... # define model with minibatch
  idata = pm.sample()
  
with remove_minibatch(minibatch_m) as m:
  pm.sample_posterior_predictive(idata)
