Skip to content

Commit

Permalink
Feature: Masked Operator (#147)
Browse files Browse the repository at this point in the history
* add masked operator base class
---------

Co-authored-by: Samuel Burbulla <[email protected]>
  • Loading branch information
JakobEliasWagner and samuelburbulla authored Jul 31, 2024
1 parent 64eecee commit cb825d7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 0.2.0

- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
- Add `MaskedOperator` base class.

## 0.1.0

Expand Down
32 changes: 32 additions & 0 deletions src/continuiti/operators/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,35 @@ def num_params(self) -> int:
def __str__(self):
"""Return string representation of the operator."""
return self.__class__.__name__


class MaskedOperator(Operator, ABC):
"""Masked operator base class.
A masked operator can apply masks during the forward pass to selectively use or ignore parts of the input. Masked
operators allow for different numbers of sensors in addition to the common property of being able to handle
varying numbers of evaluations.
"""

@abstractmethod
def forward(
self,
x: torch.Tensor,
u: torch.Tensor,
y: torch.Tensor,
sensor_mask: Optional[torch.Tensor] = None,
eval_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass through the operator.
Args:
x: Sensor positions of shape (batch_size, x_dim, num_sensors...).
u: Input function values of shape (batch_size, u_dim, num_sensors...).
y: Evaluation coordinates of shape (batch_size, y_dim, num_evaluations...).
sensor_mask: Boolean mask for x and u of shape (batch_size, 1, num_sensors...).
eval_mask: Boolean mask for y of shape (batch_size, 1, num_evaluations...).
Returns:
Evaluations of the mapped function with shape (batch_size, v_dim, num_evaluations...).
"""

0 comments on commit cb825d7

Please sign in to comment.