Skip to content

Commit

Permalink
pass **kwargs to ApproximateGP.__call__ in DeepGPLayer (#2224)
Browse files Browse the repository at this point in the history
Co-authored-by: root <[email protected]>
Co-authored-by: Geoff Pleiss <[email protected]>
  • Loading branch information
3 people authored Jan 4, 2023
1 parent e80e5cd commit 3595191
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gpytorch/models/deep_gps/deep_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __call__(self, inputs, are_samples=False, **kwargs):
inputs = inputs.expand(*inputs.shape[:-3], self.output_dims, *inputs.shape[-2:])

# Now run samples through the GP
output = ApproximateGP.__call__(self, inputs)
output = ApproximateGP.__call__(self, inputs, **kwargs)
if self.output_dims is not None:
mean = output.loc.transpose(-1, -2)
covar = BlockDiagLinearOperator(output.lazy_covariance_matrix, block_dim=-3)
Expand Down

0 comments on commit 3595191

Please sign in to comment.