Skip to content

Commit

Permalink
add base transform class
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEliasWagner committed Feb 15, 2024
1 parent 9723d65 commit 5da5f4a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/continuity/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from transform import Transform

__all__ = [
"Transform",
]
47 changes: 47 additions & 0 deletions src/continuity/transforms/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import warnings

from abc import ABC, abstractmethod


class Transform(nn.Module, ABC):
"""Abstract base class for transformations of tensors.
Transformations are applied to tensors to improve model performance, enhance generalization, handle varied input
sizes, facilitate specific features, reduce overfitting, improve computational efficiency or many other reasons.
This class takes some tensor and transforms it into some other tensor.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@abstractmethod
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""Applies the transformation.
Args:
tensor: Tensor that should be transformed.
Returns:
Transformed tensor.
"""

def backward(self, tensor: torch.Tensor) -> torch.Tensor:
"""Applies the inverse transformation (given the transformation is bijective).
When the transformation is not bijective (one-to-one correspondence of data) the inverse/backward transformation
is not applied. Instead, a warning is raised.
Args:
tensor: Transformed tensor.
Returns:
Tensor with the transformation undone (given it is possible).
"""
warnings.warn(
f"Backward pass for transformation {self.__class__.__name__} not implement! "
f"Backward pass is performed as identity!",
stacklevel=2,
)
return tensor

0 comments on commit 5da5f4a

Please sign in to comment.