From 5da5f4ac9a14aa0d3893b3b5f25525c323c06185 Mon Sep 17 00:00:00 2001 From: Jakob Date: Thu, 15 Feb 2024 13:37:31 +0100 Subject: [PATCH] add base transform class --- src/continuity/transforms/__init__.py | 5 +++ src/continuity/transforms/transform.py | 47 ++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 src/continuity/transforms/__init__.py create mode 100644 src/continuity/transforms/transform.py diff --git a/src/continuity/transforms/__init__.py b/src/continuity/transforms/__init__.py new file mode 100644 index 00000000..fbc75b71 --- /dev/null +++ b/src/continuity/transforms/__init__.py @@ -0,0 +1,5 @@ +from transform import Transform + +__all__ = [ + "Transform", +] diff --git a/src/continuity/transforms/transform.py b/src/continuity/transforms/transform.py new file mode 100644 index 00000000..a48bbccf --- /dev/null +++ b/src/continuity/transforms/transform.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import warnings + +from abc import ABC, abstractmethod + + +class Transform(nn.Module, ABC): + """Abstract base class for transformations of tensors. + + Transformations are applied to tensors to improve model performance, enhance generalization, handle varied input + sizes, facilitate specific features, reduce overfitting, improve computational efficiency or many other reasons. + This class takes some tensor and transforms it into some other tensor. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @abstractmethod + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """Applies the transformation. + + Args: + tensor: Tensor that should be transformed. + + Returns: + Transformed tensor. + """ + + def backward(self, tensor: torch.Tensor) -> torch.Tensor: + """Applies the inverse transformation (given the transformation is bijective). + + When the transformation is not bijective (one-to-one correspondence of data) the inverse/backward transformation + is not applied. Instead, a warning is raised. + + Args: + tensor: Transformed tensor. + + Returns: + Tensor with the transformation undone (given it is possible). + """ + warnings.warn( + f"Backward pass for transformation {self.__class__.__name__} not implement! " + f"Backward pass is performed as identity!", + stacklevel=2, + ) + return tensor