ae

Classes:

Name Description
Autoencoder

Wrapper for PyTorch autoencoder models that uses the network's reconstruction error to score how anomalous an example is.

AutoencoderInitialized

Represents an initialized autoencoder for anomaly detection and feature learning.

Autoencoder

Autoencoder(
    module: Type[Module],
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 0.001,
    is_feature_incremental: bool = False,
    device: str = "cpu",
    seed: int = 42,
    **kwargs
)

Bases: DeepEstimator, AnomalyDetector

Wrapper for PyTorch autoencoder models that uses the network's reconstruction error to score how anomalous a given example is.

Parameters:

Name Type Description Default
module Type[Module]

Torch Module that builds the autoencoder to be wrapped. The Module should accept parameter n_features so that the returned model's input shape can be determined based on the number of features in the initial training example.

required
loss_fn Union[str, Callable]

Loss function to be used for training the wrapped model. Can be a loss function provided by torch.nn.functional or one of the following: 'mse', 'l1', 'cross_entropy', 'binary_crossentropy', 'smooth_l1', 'kl_div'.

'mse'
optimizer_fn Union[str, Callable]

Optimizer to be used for training the wrapped model. Can be an optimizer class provided by torch.optim or one of the following: "adam", "adam_w", "sgd", "rmsprop", "lbfgs".

'sgd'
lr float

Learning rate of the optimizer.

0.001
is_feature_incremental bool

Whether the model adapts when features that were absent from earlier examples appear.

False
device str

Device to run the wrapped model on. Can be "cpu" or "cuda".

'cpu'
seed int

Random seed to be used for training the wrapped model.

42
**kwargs

Parameters to be passed to the torch.nn.Module class aside from n_features.

{}
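
The contract for `module` described above can be sketched without the wrapper: the class takes `n_features` plus any extra keyword arguments, so it can be instantiated once the first example is known. `TinyAutoencoder`, `latent_dim`, and the feature names are illustrative and not part of deep_river.

```python
import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Hypothetical module; n_features is the only argument the wrapper requires."""

    def __init__(self, n_features: int, latent_dim: int = 3):
        super().__init__()
        self.encoder = nn.Linear(n_features, latent_dim)
        self.decoder = nn.Linear(latent_dim, n_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.decoder(torch.relu(self.encoder(x))))


# The first example fixes the input width; extra kwargs go to the constructor.
first_example = {"amount": 0.2, "hour": 0.5, "v1": 0.9, "v2": 0.1}
extra_kwargs = {"latent_dim": 2}
module = TinyAutoencoder(n_features=len(first_example), **extra_kwargs)

# One row with four features in, one row with four features out.
x = torch.tensor([list(first_example.values())], dtype=torch.float32)
reconstruction = module(x)
```

Because the output has the same shape as the input, the two can be compared element by element, which is what a reconstruction-error score relies on.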

Examples:

>>> from deep_river.anomaly import Autoencoder
>>> from river import metrics
>>> from river.datasets import CreditCard
>>> import torch
>>> from torch import nn
>>> from river.compose import Pipeline
>>> from river.preprocessing import MinMaxScaler
>>> dataset = CreditCard().take(5000)
>>> metric = metrics.RollingROCAUC(window_size=5000)
>>> class MyAutoEncoder(torch.nn.Module):
...     def __init__(self, n_features, latent_dim=3):
...         super(MyAutoEncoder, self).__init__()
...         self.linear1 = nn.Linear(n_features, latent_dim)
...         self.nonlin = torch.nn.LeakyReLU()
...         self.linear2 = nn.Linear(latent_dim, n_features)
...         self.sigmoid = nn.Sigmoid()
...
...     def forward(self, X, **kwargs):
...         X = self.linear1(X)
...         X = self.nonlin(X)
...         X = self.linear2(X)
...         return self.sigmoid(X)
>>> ae = Autoencoder(module=MyAutoEncoder, lr=0.005)
>>> scaler = MinMaxScaler()
>>> model = Pipeline(scaler, ae)
>>> for x, y in dataset:
...    score = model.score_one(x)
...    model.learn_one(x=x)
...    metric.update(y, score)
...
>>> print(f"Rolling ROCAUC: {metric.get():.4f}")
Rolling ROCAUC: 0.8901

Methods:

Name Description
clone

Clones the estimator.

draw

Draws the wrapped model.

initialize_module

Builds the wrapped module from the first example and creates its optimizer.

learn_many

Performs one step of training with a batch of examples.

learn_one

Performs one step of training with a single example.

score_many

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

score_one

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.

Source code in deep_river/anomaly/ae.py
def __init__(
    self,
    module: Type[torch.nn.Module],
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 1e-3,
    is_feature_incremental: bool = False,
    device: str = "cpu",
    seed: int = 42,
    **kwargs,
):
    warnings.warn(
        "This is deprecated and will be removed in future releases. "
        "Please instead use the AutoencoderInitialized class and "
        "initialize the module beforehand"
    )

    super().__init__(
        module=module,
        loss_fn=loss_fn,
        optimizer_fn=optimizer_fn,
        lr=lr,
        is_feature_incremental=is_feature_incremental,
        device=device,
        seed=seed,
        **kwargs,
    )
    self.is_class_incremental = is_feature_incremental

clone

clone(
    new_params: dict[Any, Any] | None = None,
    include_attributes=False,
)

Clones the estimator.

Parameters:

Name Type Description Default
new_params dict[Any, Any] | None

New parameters to be passed to the cloned estimator.

None
include_attributes

If True, the attributes of the estimator will be copied to the cloned estimator. This is useful when the estimator is a transformer and the attributes are the learned parameters.

False

Returns:

Type Description
DeepEstimator

The cloned estimator.

Source code in deep_river/base.py
def clone(
    self,
    new_params: dict[Any, Any] | None = None,
    include_attributes=False,
):
    """Clones the estimator.

    Parameters
    ----------
    new_params
        New parameters to be passed to the cloned estimator.
    include_attributes
        If True, the attributes of the estimator will be copied to the
        cloned estimator. This is useful when the estimator is a
        transformer and the attributes are the learned parameters.

    Returns
    -------
    DeepEstimator
        The cloned estimator.
    """
    new_params = new_params or {}
    new_params.update(self.kwargs)
    new_params.update(self._get_params())
    new_params.update({"module": self.module_cls})

    clone = self.__class__(**new_params)
    if include_attributes:
        clone.__dict__.update(self.__dict__)
    return clone
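
Going by the listing above, the parameter dictionaries are merged with successive `update` calls, so later sources win. A minimal sketch of that ordering follows; `merge_clone_params` is a hypothetical helper, and the assumption that the estimator's current parameters include `lr` is not confirmed by this page.

```python
def merge_clone_params(new_params, stored_kwargs, current_params, module_cls):
    """Mirror the order of the dict.update calls in the clone listing."""
    merged = dict(new_params or {})
    merged.update(stored_kwargs)       # constructor **kwargs
    merged.update(current_params)      # values reported by _get_params()
    merged.update({"module": module_cls})
    return merged


merged = merge_clone_params(
    new_params={"lr": 0.1, "momentum": 0.9},
    stored_kwargs={"latent_dim": 3},
    current_params={"lr": 0.005, "device": "cpu"},  # assumed contents
    module_cls="MyAutoEncoder",  # placeholder for the module class
)
```

Under these assumptions, a key passed in `new_params` survives only if neither the stored keyword arguments nor the current parameters define it: `momentum` is kept, while `lr` falls back to the estimator's current value.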

draw

draw() -> Digraph

Draws the wrapped model.

Source code in deep_river/base.py
def draw(self) -> Digraph:
    """Draws the wrapped model."""
    first_parameter = next(self.module.parameters())
    input_shape = first_parameter.size()
    y_pred = self.module(torch.rand(input_shape))
    return make_dot(y_pred.mean(), params=dict(self.module.named_parameters()))
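
The dummy input in `draw` takes its shape from the first parameter of the module. For a module whose first layer is `nn.Linear`, that parameter is the weight matrix, shaped `(out_features, in_features)`. The sketch below shows what this implies; the layer sizes are arbitrary and the `make_dot` call is left out so that only torch is needed.

```python
import torch
from torch import nn

module = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 5))

# For nn.Linear(5, 3) the first parameter is the weight, shaped (out, in) = (3, 5).
first_parameter = next(module.parameters())
input_shape = first_parameter.size()

# A random tensor of that shape is read as a batch of 3 rows with 5 features each.
dummy_batch = torch.rand(input_shape)
y_pred = module(dummy_batch)
```

The trick works because the last dimension of the weight matches the layer's input width. A module whose first parameter is not a linear weight may need a different dummy input.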

initialize_module

initialize_module(x: dict | DataFrame, **kwargs)

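Seeds torch with self.seed, infers the number of features from x, instantiates the module class if it is not already a module instance, moves the module to the configured device, and creates the optimizer.

Parameters:

Name Type Description Default
x dict | DataFrame

Example or batch whose number of features determines the module's input width.

required
**kwargs

Keyword arguments forwarded to the module class after filtering with _filter_kwargs. Can be empty.

{}

Returns:

Type Description
None

The method sets self.module and self.optimizer in place.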

Source code in deep_river/base.py
def initialize_module(self, x: dict | pd.DataFrame, **kwargs):
    """
    Parameters
    ----------
    module
      The instance or class or callable to be initialized, e.g.
      ``self.module``.
    kwargs : dict
      The keyword arguments to initialize the instance or class. Can be an
      empty dict.
    Returns
    -------
    instance
      The initialized component.
    """
    torch.manual_seed(self.seed)
    if isinstance(x, Dict):
        n_features = len(x)
    elif isinstance(x, pd.DataFrame):
        n_features = len(x.columns)

    if not isinstance(self.module_cls, torch.nn.Module):
        self.module = self.module_cls(
            n_features=n_features,
            **self._filter_kwargs(self.module_cls, kwargs),
        )

    self.module.to(self.device)
    self.optimizer = self.optimizer_func(self.module.parameters(), lr=self.lr)
    self.module_initialized = True

    self._get_input_output_layers(n_features=n_features)
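
Two details of the listing above are worth isolating: the seed is set immediately before the module is built, and the optimizer is created by calling `optimizer_func(parameters, lr=lr)`. The sketch below reproduces both with plain torch. `build` is a hypothetical helper, `nn.Linear` stands in for the user's module class, and whether the wrapper itself accepts a `functools.partial` as `optimizer_fn` is an assumption.

```python
from functools import partial

import torch
from torch import nn


def build(seed: int, n_features: int, optimizer_func, lr: float):
    """Seed, build the module, then attach an optimizer, as in the listing."""
    torch.manual_seed(seed)
    module = nn.Linear(n_features, n_features)  # stand-in for module_cls(...)
    optimizer = optimizer_func(module.parameters(), lr=lr)
    return module, optimizer


x = {"a": 1.0, "b": 2.0, "c": 3.0}

# Any callable with the shape f(params, lr=...) fits the call in the listing.
sgd_with_momentum = partial(torch.optim.SGD, momentum=0.9)

module_1, optimizer_1 = build(42, len(x), sgd_with_momentum, lr=0.001)
module_2, _ = build(42, len(x), torch.optim.Adam, lr=0.001)
```

Seeding right before construction means two estimators created with the same seed start from identical weights, regardless of which optimizer they use.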

learn_many

learn_many(X: DataFrame) -> None

Performs one step of training with a batch of examples.

Parameters:

Name Type Description Default
X DataFrame

Input batch of examples.

required

Source code in deep_river/anomaly/ae.py
def learn_many(self, X: pd.DataFrame) -> None:
    """
    Performs one step of training with a batch of examples.

    Parameters
    ----------
    X
        Input batch of examples.
    """
    if not self.module_initialized:

        self._update_observed_features(X)
        self.initialize_module(x=X, **self.kwargs)

    self._adapt_input_dim(X)
    X_t = df2tensor(X, features=self.observed_features, device=self.device)
    self._learn(X_t)

learn_one

learn_one(x: dict, y: Any = None, **kwargs) -> None

Performs one step of training with a single example.

Parameters:

Name Type Description Default
x dict

Input example.

required
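y Any

Not used; training is unsupervised.

None
**kwargs

Additional keyword arguments; the implementation shown does not use them.

{}

Source code in deep_river/anomaly/ae.py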
def learn_one(self, x: dict, y: Any = None, **kwargs) -> None:
    """
    Performs one step of training with a single example.

    Parameters
    ----------
    x
        Input example.

    **kwargs
    """
    if not self.module_initialized:
        self._update_observed_features(x)
        self.initialize_module(x=x, **self.kwargs)
    self._adapt_input_dim(x)
    self._learn(dict2tensor(x, features=self.observed_features, device=self.device))
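
The `_learn` helper is not shown on this page. A plausible reading is that it performs one gradient step on the reconstruction loss, with the input serving as its own target. The following is a minimal sketch under that assumption; `learn_step` and the layer sizes are illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
module = nn.Sequential(nn.Linear(4, 2), nn.LeakyReLU(), nn.Linear(2, 4))
optimizer = torch.optim.SGD(module.parameters(), lr=0.005)


def learn_step(x_t: torch.Tensor) -> float:
    """One gradient step that pulls the reconstruction towards the input."""
    module.train()
    optimizer.zero_grad()
    loss = F.mse_loss(module(x_t), x_t)  # the input is also the target
    loss.backward()
    optimizer.step()
    return loss.item()


x_t = torch.tensor([[0.1, 0.4, 0.7, 0.9]])
weights_before = module[2].weight.detach().clone()
loss_value = learn_step(x_t)
weights_after = module[2].weight.detach().clone()
```

Each call updates the weights once, which matches the "one step of training" wording used for `learn_one` and `learn_many`.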

score_many

score_many(X: DataFrame) -> ndarray

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
X DataFrame

Input batch of examples.

required

Returns:

Type Description
ndarray

Anomaly scores for the given batch of examples. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_many(self, X: pd.DataFrame) -> np.ndarray:
    """
    Returns an anomaly score for the provided batch of examples in
    the form of the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input batch of examples.

    Returns
    -------
    float
        Anomaly scores for the given batch of examples. Larger values
        indicate more anomalous examples.
    """
    if not self.module_initialized:
        self._update_observed_features(X)
        self.initialize_module(x=X, **self.kwargs)

    self._adapt_input_dim(X)
    X_t = df2tensor(X, features=self.observed_features, device=self.device)

    self.module.eval()
    with torch.inference_mode():
        X_pred = self.module(X_t)
    loss = torch.mean(
        self.loss_func(X_pred, X_t, reduction="none"),
        dim=list(range(1, X_t.dim())),
    )
    score = loss.cpu().detach().numpy()
    return score
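
The reduction in the listing above can be checked on a small batch: the loss is computed elementwise with `reduction="none"` and then averaged over every dimension except the first, which leaves one score per row. The tensors below are made-up values, and `tolist()` replaces the NumPy conversion to keep the sketch free of extra dependencies.

```python
import torch
import torch.nn.functional as F

X_t = torch.tensor([[0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0]])
X_pred = torch.tensor([[0.0, 0.0, 0.0],
                       [1.0, 0.0, 1.0]])  # second row misses one feature

# Elementwise squared errors, same shape as the batch.
elementwise = F.mse_loss(X_pred, X_t, reduction="none")

# Average over every dimension except the batch dimension.
scores = torch.mean(elementwise, dim=list(range(1, X_t.dim()))).tolist()
```

Since the listing passes `reduction="none"` to the loss, a custom callable used as the loss function would presumably need to accept that keyword for batch scoring to work.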

score_one

score_one(x: dict) -> float

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
x dict

Input example.

required

Returns:

Type Description
float

Anomaly score for the given example. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_one(self, x: dict) -> float:
    """
    Returns an anomaly score for the provided example in the form of
    the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input example.

    Returns
    -------
    float
        Anomaly score for the given example. Larger values indicate
        more anomalous examples.

    """

    if not self.module_initialized:
        self._update_observed_features(x)
        self.initialize_module(x=x, **self.kwargs)

    self._adapt_input_dim(x)

    x_t = dict2tensor(x, features=self.observed_features, device=self.device)
    self.module.eval()
    with torch.inference_mode():
        x_pred = self.module(x_t)
    loss = self.loss_func(x_pred, x_t).item()
    return loss
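
For a single example, the listing converts the dictionary to a tensor in the order of the observed features and lets the loss reduce to one scalar. The sketch below illustrates that flow. `dict_to_tensor` is a hypothetical stand-in for `dict2tensor`, and filling missing features with 0.0 is an assumption, not documented behavior.

```python
import torch
import torch.nn.functional as F


def dict_to_tensor(x: dict, features: list) -> torch.Tensor:
    """Hypothetical stand-in for dict2tensor: one row, columns in `features` order."""
    return torch.tensor([[float(x.get(name, 0.0)) for name in features]])


observed_features = ["amount", "hour"]
x = {"hour": 0.5, "amount": 0.25}  # key order differs from observed_features

x_t = dict_to_tensor(x, observed_features)
x_pred = torch.tensor([[0.25, 0.0]])  # made-up reconstruction

# The default reduction="mean" collapses all elements to one scalar.
score = F.mse_loss(x_pred, x_t).item()
```

Fixing the column order matters because the module sees only positions, not feature names; the same dictionary must always map to the same tensor layout.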

AutoencoderInitialized

AutoencoderInitialized(
    module: Module,
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 0.001,
    is_feature_incremental: bool = False,
    device: str = "cpu",
    seed: int = 42,
    **kwargs
)

Bases: DeepEstimatorInitialized, AnomalyDetector

Represents an initialized autoencoder for anomaly detection and feature learning.

Unlike Autoencoder, this class takes an already instantiated torch.nn.Module rather than a module class. It trains the autoencoder on unlabeled input, either one example or one batch at a time, and computes anomaly scores from the reconstruction error.

Attributes:

Name Type Description
is_feature_incremental bool

Whether the model adapts when features that were absent from earlier examples appear.

module Module

The PyTorch model representing the autoencoder architecture.

loss_fn Union[str, Callable]

Specifies the loss function to compute the reconstruction error.

optimizer_fn Union[str, Callable]

Specifies the optimizer to be used for training the autoencoder.

lr float

The learning rate for optimization.

device str

The device on which the model is loaded and trained (e.g., "cpu", "cuda").

seed int

Random seed for ensuring reproducibility.

Methods:

Name Description
learn_many

Performs one step of training with a batch of examples.

learn_one

Performs one step of training with a single example.

score_many

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

score_one

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.

Source code in deep_river/anomaly/ae.py
def __init__(
    self,
    module: torch.nn.Module,
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 1e-3,
    is_feature_incremental: bool = False,
    device: str = "cpu",
    seed: int = 42,
    **kwargs,
):
    super().__init__(
        module=module,
        loss_fn=loss_fn,
        optimizer_fn=optimizer_fn,
        lr=lr,
        is_feature_incremental=is_feature_incremental,
        device=device,
        seed=seed,
        **kwargs,
    )
    self.is_feature_incremental = is_feature_incremental

learn_many

learn_many(X: DataFrame) -> None

Performs one step of training with a batch of examples.

Parameters:

Name Type Description Default
X DataFrame

Input batch of examples.

required

Source code in deep_river/anomaly/ae.py
def learn_many(self, X: pd.DataFrame) -> None:
    """
    Performs one step of training with a batch of examples.

    Parameters
    ----------
    X
        Input batch of examples.
    """

    self._update_observed_features(X)
    X_t = self._df2tensor(X)
    self._learn(X_t)

learn_one

learn_one(x: dict, y: Any = None, **kwargs) -> None

Performs one step of training with a single example.

Parameters:

Name Type Description Default
x dict

Input example.

required
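y Any

Not used; training is unsupervised.

None
**kwargs

Additional keyword arguments; the implementation shown does not use them.

{}

Source code in deep_river/anomaly/ae.py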
def learn_one(self, x: dict, y: Any = None, **kwargs) -> None:
    """
    Performs one step of training with a single example.

    Parameters
    ----------
    x
        Input example.

    **kwargs
    """
    self._update_observed_features(x)
    self._learn(self._dict2tensor(x))

score_many

score_many(X: DataFrame) -> ndarray

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
X DataFrame

Input batch of examples.

required

Returns:

Type Description
ndarray

Anomaly scores for the given batch of examples. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_many(self, X: pd.DataFrame) -> np.ndarray:
    """
    Returns an anomaly score for the provided batch of examples in
    the form of the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input batch of examples.

    Returns
    -------
    float
        Anomaly scores for the given batch of examples. Larger values
        indicate more anomalous examples.
    """
    self._update_observed_features(X)
    x_t = self._df2tensor(X)

    self.module.eval()
    with torch.inference_mode():
        x_pred = self.module(x_t)
    loss = torch.mean(
        self.loss_func(x_pred, x_t, reduction="none"),
        dim=list(range(1, x_t.dim())),
    )
    score = loss.cpu().detach().numpy()
    return score

score_one

score_one(x: dict) -> float

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
x dict

Input example.

required

Returns:

Type Description
float

Anomaly score for the given example. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_one(self, x: dict) -> float:
    """
    Returns an anomaly score for the provided example in the form of
    the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input example.

    Returns
    -------
    float
        Anomaly score for the given example. Larger values indicate
        more anomalous examples.

    """

    self._update_observed_features(x)
    x_t = self._dict2tensor(x)
    self.module.eval()
    with torch.inference_mode():
        x_pred = self.module(x_t)
    loss = self.loss_func(x_pred, x_t).item()
    return loss
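
Both scoring methods switch the module to evaluation mode and run the forward pass under `torch.inference_mode()`. The sketch below shows the effect of each on a small module with a dropout layer, chosen because dropout behaves differently in training and evaluation; the architecture is illustrative only.

```python
import torch
from torch import nn

module = nn.Sequential(nn.Linear(3, 2), nn.Dropout(p=0.5), nn.Linear(2, 3))
x_t = torch.rand(1, 3)

module.eval()  # dropout becomes a no-op, so repeated passes agree
with torch.inference_mode():
    first = module(x_t)   # no autograd graph is recorded here
    second = module(x_t)
```

Scores are therefore deterministic for a given set of weights, and scoring leaves no gradient state behind for the next training step.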