probability_weighted_ae

Classes:

Name Description
ProbabilityWeightedAutoencoder

Wrapper for PyTorch autoencoder models for anomaly detection that scales the learning rate based on an outlier probability estimate of the input example.

ProbabilityWeightedAutoencoderInitialized

Deprecated variant of ProbabilityWeightedAutoencoder that wraps an already instantiated module.

ProbabilityWeightedAutoencoder

ProbabilityWeightedAutoencoder(
    module: Type[Module],
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 0.001,
    device: str = "cpu",
    seed: int = 42,
    skip_threshold: float = 0.9,
    window_size=250,
    **kwargs
)

Bases: Autoencoder

Wrapper for PyTorch autoencoder models for anomaly detection that scales the learning rate using an outlier probability estimate of the input example and a threshold probability skip_threshold. Given the probability estimate \(p_{out}\), the adjusted learning rate \(lr_{adj}\) is \(lr \cdot \left(1 - \frac{p_{out}}{skip\_threshold}\right)\). The adjusted rate shrinks as the outlier probability grows, reaches 0 when the probability equals the threshold, and becomes negative when the probability exceeds the threshold.
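
As a rough illustration of the scaling rule, the sketch below evaluates the adjusted learning rate for a few outlier probabilities. It reads the formula as lr * (1 - p_out / skip_threshold), which is an interpretation of the description rather than code taken from the library; the function name adjusted_learning_rate is hypothetical.

```python
def adjusted_learning_rate(
    lr: float, p_out: float, skip_threshold: float = 0.9
) -> float:
    """Scale `lr` by 1 - p_out / skip_threshold (assumed reading of the docs)."""
    if not 0.0 < skip_threshold <= 1.0:
        raise ValueError("skip_threshold must lie in (0, 1]")
    return lr * (1.0 - p_out / skip_threshold)


base_lr = 0.001
for p_out in (0.0, 0.45, 0.9, 0.99):
    # The rate shrinks with p_out and changes sign past the threshold.
    print(f"p_out={p_out:.2f} -> lr_adj={adjusted_learning_rate(base_lr, p_out):+.6f}")
```

With the default skip_threshold of 0.9, the rate is halved at p_out = 0.45, vanishes at p_out = 0.9, and turns negative beyond that.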

Parameters:

Name Type Description Default
module Type[Module]

Torch Module that builds the autoencoder to be wrapped. The Module should accept parameter n_features so that the returned model's input shape can be determined based on the number of features in the initial training example.

required
loss_fn Union[str, Callable]

Loss function to be used for training the wrapped model. Can be a loss function provided by torch.nn.functional or one of the following: 'mse', 'l1', 'cross_entropy', 'binary_crossentropy', 'smooth_l1', 'kl_div'.

'mse'
optimizer_fn Union[str, Callable]

Optimizer to be used for training the wrapped model. Can be an optimizer class provided by torch.optim or one of the following: "adam", "adam_w", "sgd", "rmsprop", "lbfgs".

'sgd'
lr float

Base learning rate of the optimizer.

0.001
skip_threshold float

Threshold probability to use as a reference for the reduction of the base learning rate.

0.9
device str

Device to run the wrapped model on. Can be "cpu" or "cuda".

'cpu'
seed int

Random seed to be used for training the wrapped model.

42
**kwargs

Parameters to be passed to the module function aside from n_features.

{}

Methods:

Name Description
clone

Clones the estimator.

draw

Draws the wrapped model.

initialize_module

Builds the wrapped module and its optimizer from the first observed example.

learn_one

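Performs one step of training with a single example, scaling the employed learning rate based on the outlier probability estimate of the input example.

score_many

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

score_one

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.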

Source code in deep_river/anomaly/probability_weighted_ae.py
def __init__(
    self,
    module: Type[torch.nn.Module],
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 1e-3,
    device: str = "cpu",
    seed: int = 42,
    skip_threshold: float = 0.9,
    window_size=250,
    **kwargs,
):
    super().__init__(
        module=module,
        loss_fn=loss_fn,
        optimizer_fn=optimizer_fn,
        lr=lr,
        device=device,
        seed=seed,
        **kwargs,
    )
    self.window_size = window_size
    self.skip_threshold = skip_threshold
    self.rolling_mean = utils.Rolling(stats.Mean(), window_size=window_size)
    self.rolling_var = utils.Rolling(stats.Var(), window_size=window_size)
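
The constructor keeps a rolling mean and a rolling variance, but the code that consumes them is not shown on this page. One plausible use, sketched here purely as an assumption, is to standardize each new loss against the recent window and map it through the standard normal CDF to obtain an outlier probability. RollingStats and outlier_probability are hypothetical stand-ins written with the standard library; they are not river's utils.Rolling.

```python
import math
from collections import deque


class RollingStats:
    """Mean and sample variance over the last `window_size` values."""

    def __init__(self, window_size: int = 250):
        self.values = deque(maxlen=window_size)  # old values fall out automatically

    def update(self, value: float) -> None:
        self.values.append(value)

    def mean(self) -> float:
        return sum(self.values) / len(self.values) if self.values else 0.0

    def var(self) -> float:
        n = len(self.values)
        if n < 2:
            return 0.0
        m = self.mean()
        return sum((v - m) ** 2 for v in self.values) / (n - 1)


def outlier_probability(loss: float, stats: RollingStats) -> float:
    """Standard normal CDF of the standardized loss (assumed estimator)."""
    var = stats.var()
    std = math.sqrt(var) if var > 0 else 1.0  # avoid dividing by zero early on
    z = (loss - stats.mean()) / std
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


stats = RollingStats(window_size=5)
for loss in (0.10, 0.12, 0.11, 0.09, 0.10):
    stats.update(loss)

p_typical = outlier_probability(0.10, stats)  # close to the window mean
p_extreme = outlier_probability(0.50, stats)  # far above the window mean
```

A loss near the window mean gets a probability around one half or lower, while a loss many standard deviations above it gets a probability close to 1.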

clone

clone(
    new_params: dict[Any, Any] | None = None,
    include_attributes=False,
)

Clones the estimator.

Parameters:

Name Type Description Default
new_params dict[Any, Any] | None

New parameters to be passed to the cloned estimator.

None
include_attributes

If True, the attributes of the estimator will be copied to the cloned estimator. This is useful when the estimator is a transformer and the attributes are the learned parameters.

False

Returns:

Type Description
DeepEstimator

The cloned estimator.

Source code in deep_river/base.py
def clone(
    self,
    new_params: dict[Any, Any] | None = None,
    include_attributes=False,
):
    """Clones the estimator.

    Parameters
    ----------
    new_params
        New parameters to be passed to the cloned estimator.
    include_attributes
        If True, the attributes of the estimator will be copied to the
        cloned estimator. This is useful when the estimator is a
        transformer and the attributes are the learned parameters.

    Returns
    -------
    DeepEstimator
        The cloned estimator.
    """
    new_params = new_params or {}
    new_params.update(self.kwargs)
    new_params.update(self._get_params())
    new_params.update({"module": self.module_cls})

    clone = self.__class__(**new_params)
    if include_attributes:
        clone.__dict__.update(self.__dict__)
    return clone
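
The order of the update calls in clone determines which value wins when a key appears more than once. The sketch below replays that order with plain dictionaries; merge_clone_params is a hypothetical helper, and the example values standing in for self.kwargs and self._get_params() are assumptions, since neither is shown on this page. Unlike the listing, the sketch copies new_params so the caller's dictionary is left untouched.

```python
def merge_clone_params(new_params, kwargs, current_params, module_cls):
    """Replay the dict.update order used in clone()."""
    params = dict(new_params or {})  # copy; the original listing updates in place
    params.update(kwargs)            # stand-in for self.kwargs
    params.update(current_params)    # stand-in for self._get_params()
    params.update({"module": module_cls})
    return params


merged = merge_clone_params(
    new_params={"lr": 0.1, "momentum": 0.9},
    kwargs={"latent_dim": 3},
    current_params={"lr": 0.001, "device": "cpu"},
    module_cls="MyAutoencoder",  # placeholder for a module class
)
print(merged)
```

Because later updates overwrite earlier ones, a key supplied in new_params survives only if neither of the later dictionaries also defines it. In this example momentum is kept, while lr ends up with the value from the stand-in for the current parameters.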

draw

draw() -> Digraph

Draws the wrapped model.

Source code in deep_river/base.py
def draw(self) -> Digraph:
    """Draws the wrapped model."""
    first_parameter = next(self.module.parameters())
    input_shape = first_parameter.size()
    y_pred = self.module(torch.rand(input_shape))
    return make_dot(y_pred.mean(), params=dict(self.module.named_parameters()))

initialize_module

initialize_module(x: dict | DataFrame, **kwargs)

Parameters:

Name Type Description Default
x dict | DataFrame

Example or batch of examples whose number of features is passed to the module as n_features.

required
**kwargs

Keyword arguments used to initialize the module. Can be empty.

{}

Returns:

Type Description
None

Nothing is returned; the module and optimizer are set up in place on the estimator.

Source code in deep_river/base.py
def initialize_module(self, x: dict | pd.DataFrame, **kwargs):
    """
    Parameters
    ----------
    module
      The instance or class or callable to be initialized, e.g.
      ``self.module``.
    kwargs : dict
      The keyword arguments to initialize the instance or class. Can be an
      empty dict.
    Returns
    -------
    instance
      The initialized component.
    """
    torch.manual_seed(self.seed)
    if isinstance(x, Dict):
        n_features = len(x)
    elif isinstance(x, pd.DataFrame):
        n_features = len(x.columns)

    if not isinstance(self.module_cls, torch.nn.Module):
        self.module = self.module_cls(
            n_features=n_features,
            **self._filter_kwargs(self.module_cls, kwargs),
        )

    self.module.to(self.device)
    self.optimizer = self.optimizer_func(self.module.parameters(), lr=self.lr)
    self.module_initialized = True

    self._get_input_output_layers(n_features=n_features)
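
In the listing above, an input that is neither a dict nor a DataFrame leaves n_features unassigned, so the failure only surfaces later as an unbound-variable error. A minimal sketch of the same inference with an explicit error is shown below. infer_n_features and FrameLike are hypothetical names; FrameLike only imitates the columns attribute of a DataFrame so the example runs without pandas.

```python
class FrameLike:
    """Hypothetical stand-in for a DataFrame: only exposes `columns`."""

    def __init__(self, columns):
        self.columns = list(columns)


def infer_n_features(x) -> int:
    """Return the number of features of a single example or a batch."""
    if isinstance(x, dict):
        return len(x)  # one key per feature
    columns = getattr(x, "columns", None)
    if columns is not None:
        return len(columns)  # one column per feature
    raise TypeError(f"unsupported input type: {type(x).__name__}")


n_from_dict = infer_n_features({"a": 1.0, "b": 2.0, "c": 3.0})
n_from_frame = infer_n_features(FrameLike(["a", "b"]))
```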

learn_one

learn_one(x: dict, y: Any = None, **kwargs) -> None

Performs one step of training with a single example, scaling the employed learning rate based on the outlier probability estimate of the input example.

Parameters:

Name Type Description Default
x dict

Input example.

required
y Any

Unused target value.

None
**kwargs

Additional keyword arguments; unused.

{}

Returns:

Type Description
None

The model is updated in place.

Source code in deep_river/anomaly/probability_weighted_ae.py
def learn_one(self, x: dict, y: Any = None, **kwargs) -> None:
    """
    Performs one step of training with a single example,
    scaling the employed learning rate based on the outlier
    probability estimate of the input example.

    Parameters
    ----------
    **kwargs
    x
        Input example.

    Returns
    -------
    ProbabilityWeightedAutoencoder
        The autoencoder itself.
    """
    if not self.module_initialized:
        self._update_observed_features(x)
        self.initialize_module(x=x, **self.kwargs)

    self._adapt_input_dim(x)
    x_t = dict2tensor(x, self.observed_features, device=self.device)

    self.module.train()
    x_pred = self.module(x_t)
    loss = self.loss_func(x_pred, x_t)
    self._apply_loss(loss)
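
The helper _apply_loss is not shown on this page, so the following is only a toy sketch of what a probability-weighted update could look like. It uses a one-parameter "autoencoder" that reconstructs x as w * x and takes one plain gradient step with the learning rate scaled by 1 - p_out / skip_threshold. The function sgd_step and all values are hypothetical.

```python
def sgd_step(w: float, x: float, lr: float, p_out: float,
             skip_threshold: float = 0.9) -> float:
    """One gradient step on (w * x - x) ** 2 with a scaled learning rate."""
    grad = 2.0 * (w * x - x) * x                  # d/dw of the squared error
    lr_adj = lr * (1.0 - p_out / skip_threshold)  # assumed scaling rule
    return w - lr_adj * grad


w0, x = 0.5, 2.0  # perfect reconstruction would need w == 1.0
w_inlier = sgd_step(w0, x, lr=0.1, p_out=0.0)         # full step towards 1.0
w_at_threshold = sgd_step(w0, x, lr=0.1, p_out=0.9)   # no update
w_outlier = sgd_step(w0, x, lr=0.1, p_out=0.99)       # small step away from 1.0
```

Under this reading, a likely inlier moves the weight towards a better reconstruction, an example at the threshold is skipped, and an example above the threshold pushes the weight slightly in the opposite direction.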

score_many

score_many(X: DataFrame) -> ndarray

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
X DataFrame

Input batch of examples.

required

Returns:

Type Description
ndarray

Anomaly scores for the given batch of examples. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_many(self, X: pd.DataFrame) -> np.ndarray:
    """
    Returns an anomaly score for the provided batch of examples in
    the form of the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input batch of examples.

    Returns
    -------
    float
        Anomaly scores for the given batch of examples. Larger values
        indicate more anomalous examples.
    """
    if not self.module_initialized:
        self._update_observed_features(X)
        self.initialize_module(x=X, **self.kwargs)

    self._adapt_input_dim(X)
    X_t = df2tensor(X, features=self.observed_features, device=self.device)

    self.module.eval()
    with torch.inference_mode():
        X_pred = self.module(X_t)
    loss = torch.mean(
        self.loss_func(X_pred, X_t, reduction="none"),
        dim=list(range(1, X_t.dim())),
    )
    score = loss.cpu().detach().numpy()
    return score
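
The reduction in score_many averages the elementwise loss over every dimension except the first, so each row of the batch gets its own score. The pure-Python sketch below mimics that for a two-dimensional batch with squared error. per_example_mse is a hypothetical helper, and the reconstructions are made up rather than produced by a model.

```python
def per_example_mse(batch, reconstructions):
    """Mean squared error per row, i.e. one anomaly score per example."""
    scores = []
    for row, row_pred in zip(batch, reconstructions):
        squared_errors = [(p - v) ** 2 for v, p in zip(row, row_pred)]
        scores.append(sum(squared_errors) / len(squared_errors))
    return scores


batch = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
reconstructions = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]  # invented model outputs
scores = per_example_mse(batch, reconstructions)
```

The first row is reconstructed exactly and scores 0, while the second row is reconstructed poorly and receives a larger score.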

score_one

score_one(x: dict) -> float

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
x dict

Input example.

required

Returns:

Type Description
float

Anomaly score for the given example. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_one(self, x: dict) -> float:
    """
    Returns an anomaly score for the provided example in the form of
    the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input example.

    Returns
    -------
    float
        Anomaly score for the given example. Larger values indicate
        more anomalous examples.

    """

    if not self.module_initialized:
        self._update_observed_features(x)
        self.initialize_module(x=x, **self.kwargs)

    self._adapt_input_dim(x)

    x_t = dict2tensor(x, features=self.observed_features, device=self.device)
    self.module.eval()
    with torch.inference_mode():
        x_pred = self.module(x_t)
    loss = self.loss_func(x_pred, x_t).item()
    return loss
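
Since score_one takes a feature dict and returns a float while learn_one updates the model in place, the two calls can be interleaved over a stream by scoring each example before learning from it. The sketch below shows that loop with a hypothetical stand-in model, MeanReconstructor, which "reconstructs" every feature by its running mean. It is not the library's autoencoder and only mirrors the two method signatures.

```python
class MeanReconstructor:
    """Hypothetical stand-in: reconstructs each feature by its running mean."""

    def __init__(self):
        self.n = 0
        self.means = {}

    def score_one(self, x: dict) -> float:
        # Mean squared reconstruction error over the features of x.
        if not x:
            return 0.0
        errors = [(self.means.get(k, 0.0) - v) ** 2 for k, v in x.items()]
        return sum(errors) / len(errors)

    def learn_one(self, x: dict) -> None:
        # Incremental update of the per-feature running means.
        self.n += 1
        for k, v in x.items():
            old = self.means.get(k, 0.0)
            self.means[k] = old + (v - old) / self.n


stream = [
    {"a": 1.0, "b": 2.0},
    {"a": 1.1, "b": 1.9},
    {"a": 0.9, "b": 2.1},
    {"a": 9.0, "b": -5.0},  # deliberately unusual example
]

model = MeanReconstructor()
scores = []
for x in stream:
    scores.append(model.score_one(x))  # score first ...
    model.learn_one(x)                 # ... then learn from the same example
```

The unusual last example receives by far the largest score, because it is scored before the model has had a chance to adapt to it.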

ProbabilityWeightedAutoencoderInitialized

ProbabilityWeightedAutoencoderInitialized(
    module: Module,
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 0.001,
    device: str = "cpu",
    seed: int = 42,
    skip_threshold: float = 0.9,
    window_size=250,
    **kwargs
)

Bases: AutoencoderInitialized

Methods:

Name Description
learn_one

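Performs one step of training with a single example, scaling the employed learning rate based on the outlier probability estimate of the input example.

score_many

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

score_one

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.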

Source code in deep_river/anomaly/probability_weighted_ae.py
def __init__(
    self,
    module: torch.nn.Module,
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    lr: float = 1e-3,
    device: str = "cpu",
    seed: int = 42,
    skip_threshold: float = 0.9,
    window_size=250,
    **kwargs,
):
    warnings.warn(
        "This is deprecated and will be removed in future releases. "
        "Please instead use the AutoencoderInitialized class and "
        "initialize the module beforehand"
    )
    super().__init__(
        module=module,
        loss_fn=loss_fn,
        optimizer_fn=optimizer_fn,
        lr=lr,
        device=device,
        seed=seed,
        **kwargs,
    )
    self.window_size = window_size
    self.skip_threshold = skip_threshold
    self.rolling_mean = utils.Rolling(stats.Mean(), window_size=window_size)
    self.rolling_var = utils.Rolling(stats.Var(), window_size=window_size)
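
The constructor calls warnings.warn without a category argument, and Python's default category in that case is UserWarning rather than DeprecationWarning. The sketch below reproduces that behaviour with a hypothetical function, deprecated_constructor, so the effect can be checked without importing the library.

```python
import warnings


def deprecated_constructor() -> None:
    """Hypothetical stand-in that warns the same way as the listing above."""
    warnings.warn("This is deprecated and will be removed in future releases.")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not suppressed
    deprecated_constructor()

category = caught[0].category
print(category.__name__)
```

A filter that targets only DeprecationWarning would therefore not match this message; filtering on UserWarning or on the message text would.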

learn_one

learn_one(x: dict, y: Any = None, **kwargs) -> None

Performs one step of training with a single example, scaling the employed learning rate based on the outlier probability estimate of the input example.

Parameters:

Name Type Description Default
x dict

Input example.

required
y Any

Unused target value.

None
**kwargs

Additional keyword arguments; unused.

{}

Returns:

Type Description
None

The model is updated in place.

Source code in deep_river/anomaly/probability_weighted_ae.py
def learn_one(self, x: dict, y: Any = None, **kwargs) -> None:
    """
    Performs one step of training with a single example,
    scaling the employed learning rate based on the outlier
    probability estimate of the input example.

    Parameters
    ----------
    **kwargs
    x
        Input example.

    Returns
    -------
    ProbabilityWeightedAutoencoder
        The autoencoder itself.
    """

    self._update_observed_features(x)
    x_t = self._dict2tensor(x)

    self.module.train()
    x_pred = self.module(x_t)
    loss = self.loss_func(x_pred, x_t)
    self._apply_loss(loss)

score_many

score_many(X: DataFrame) -> ndarray

Returns an anomaly score for the provided batch of examples in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
X DataFrame

Input batch of examples.

required

Returns:

Type Description
ndarray

Anomaly scores for the given batch of examples. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_many(self, X: pd.DataFrame) -> np.ndarray:
    """
    Returns an anomaly score for the provided batch of examples in
    the form of the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input batch of examples.

    Returns
    -------
    float
        Anomaly scores for the given batch of examples. Larger values
        indicate more anomalous examples.
    """
    self._update_observed_features(X)
    x_t = self._df2tensor(X)

    self.module.eval()
    with torch.inference_mode():
        x_pred = self.module(x_t)
    loss = torch.mean(
        self.loss_func(x_pred, x_t, reduction="none"),
        dim=list(range(1, x_t.dim())),
    )
    score = loss.cpu().detach().numpy()
    return score

score_one

score_one(x: dict) -> float

Returns an anomaly score for the provided example in the form of the autoencoder's reconstruction error.

Parameters:

Name Type Description Default
x dict

Input example.

required

Returns:

Type Description
float

Anomaly score for the given example. Larger values indicate more anomalous examples.

Source code in deep_river/anomaly/ae.py
def score_one(self, x: dict) -> float:
    """
    Returns an anomaly score for the provided example in the form of
    the autoencoder's reconstruction error.

    Parameters
    ----------
    x
        Input example.

    Returns
    -------
    float
        Anomaly score for the given example. Larger values indicate
        more anomalous examples.

    """

    self._update_observed_features(x)
    x_t = self._dict2tensor(x)
    self.module.eval()
    with torch.inference_mode():
        x_pred = self.module(x_t)
    loss = self.loss_func(x_pred, x_t).item()
    return loss