
rolling_regressor

RollingRegressor(module, loss_fn='mse', optimizer_fn='sgd', lr=0.001, window_size=10, append_predict=False, device='cpu', seed=42, **kwargs)

Bases: RollingDeepEstimator, Regressor

Wrapper that feeds a sliding window of the most recent examples to the wrapped PyTorch regression model.
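The sliding-window idea can be sketched with a bounded collections.deque. This is a minimal sketch that assumes each example's feature values are stored as a row in dict insertion order. The helper name push_example is hypothetical and not part of the library.

```python
from collections import deque


def push_example(window: deque, x: dict) -> list:
    """Append one example's feature values and return the window as a list of rows."""
    window.append(list(x.values()))  # feature order follows the dict's insertion order
    return list(window)


window_size = 3
window = deque(maxlen=window_size)  # once full, each append evicts the oldest row

stream = [{"a": float(i), "b": float(10 * i)} for i in range(5)]
lengths = []
for x in stream:
    rows = push_example(window, x)
    lengths.append(len(rows))

print(lengths)  # [1, 2, 3, 3, 3]
print(rows)     # [[2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
```

In this sketch, the window holds fewer than window_size rows during the first window_size - 1 steps and then keeps only the most recent window_size examples.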

PARAMETER DESCRIPTION
module

Torch Module class that builds the regression model to be wrapped. The Module should accept the parameter n_features so that the input shape of the returned model can be determined from the number of features in the first training example.

TYPE: Type[Module]
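The n_features contract amounts to lazy initialization: the module class is instantiated only when the first example arrives, with n_features=len(x), and any extra keyword arguments are forwarded to it. The sketch below uses a plain-Python linear model as a stand-in for a torch Module so that it runs without PyTorch. LinearStandIn and LazyWrapper are hypothetical names, and the real wrapper may differ in detail.

```python
class LinearStandIn:
    """Hypothetical stand-in for a torch Module whose size depends on n_features."""

    def __init__(self, n_features: int, bias: float = 0.0):
        self.weights = [0.0] * n_features
        self.bias = bias

    def __call__(self, features: list) -> float:
        return sum(w * f for w, f in zip(self.weights, features)) + self.bias


class LazyWrapper:
    """Hypothetical wrapper that builds the model on the first learn_one call."""

    def __init__(self, module, **kwargs):
        self.module_cls = module  # the class itself, not an instance
        self.kwargs = kwargs      # extra arguments forwarded to the module
        self.model = None

    def learn_one(self, x: dict, y: float):
        if self.model is None:
            # The input size is only known once the first example is seen.
            self.model = self.module_cls(n_features=len(x), **self.kwargs)
        # A real wrapper would train on the rolling window here.
        return self


wrapper = LazyWrapper(LinearStandIn, bias=0.5)
wrapper.learn_one({"a": 1.0, "b": 2.0, "c": 3.0}, y=1.0)
print(len(wrapper.model.weights))  # 3
```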

loss_fn

Loss function to be used for training the wrapped model. Can be a loss function provided by torch.nn.functional or one of the following: 'mse', 'l1', 'cross_entropy', 'binary_crossentropy', 'smooth_l1', 'kl_div'.

TYPE: Union[str, Callable] DEFAULT: 'mse'
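Because loss_fn accepts either a string or a callable, the wrapper has to resolve names to functions. The following is a minimal sketch of that kind of resolution, assuming a simple name-to-function registry. The names resolve_loss and LOSSES are hypothetical, and the pure-Python mse and l1 functions stand in for their torch.nn.functional counterparts.

```python
from typing import Callable, Sequence, Union


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over paired values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error over paired values."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


# Hypothetical registry; the library's own mapping covers more names.
LOSSES = {"mse": mse, "l1": l1}


def resolve_loss(loss_fn: Union[str, Callable]) -> Callable:
    if callable(loss_fn):
        return loss_fn  # a user-supplied function is used unchanged
    try:
        return LOSSES[loss_fn]
    except KeyError:
        raise ValueError(f"Unknown loss function: {loss_fn!r}") from None


print(resolve_loss("mse")([1.0, 2.0], [0.0, 0.0]))  # 2.5
```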

optimizer_fn

Optimizer to be used for training the wrapped model. Can be an optimizer class provided by torch.optim or one of the following: "adam", "adam_w", "sgd", "rmsprop", "lbfgs".

TYPE: Union[str, Callable] DEFAULT: 'sgd'

lr

Learning rate of the optimizer.

TYPE: float DEFAULT: 0.001

device

Device to run the wrapped model on. Can be "cpu" or "cuda".

TYPE: str DEFAULT: 'cpu'

seed

Random seed to be used for training the wrapped model.

TYPE: int DEFAULT: 42

window_size

Number of recent examples to be fed to the wrapped model at each step.

TYPE: int DEFAULT: 10

append_predict

Whether inputs passed to predict_one are also appended to the rolling window, so that they become part of the window used in later steps.

TYPE: bool DEFAULT: False
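The effect of append_predict can be sketched as follows. This assumes the prediction window is a copy of the stored window with x appended, and that the stored window is updated only when append_predict is True. The helper prediction_window is hypothetical.

```python
from collections import deque


def prediction_window(window: deque, x: dict, append_predict: bool) -> list:
    """Return the rows a model would see when predicting for x."""
    candidate = deque(window, maxlen=window.maxlen)  # copy, so the stored window is untouched
    candidate.append(list(x.values()))
    if append_predict:
        window.append(list(x.values()))  # keep x in the stored window as well
    return list(candidate)


results = {}
for flag in (False, True):
    window = deque([[1.0], [2.0]], maxlen=3)
    rows = prediction_window(window, {"a": 3.0}, append_predict=flag)
    results[flag] = (rows, list(window))
    print(flag, rows, list(window))
# False [[1.0], [2.0], [3.0]] [[1.0], [2.0]]
# True [[1.0], [2.0], [3.0]] [[1.0], [2.0], [3.0]]
```

With the default False, predictions do not change the context used by later learn_one calls. With True, prediction inputs also shift the window forward.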

**kwargs

Parameters to be passed to the Module or the optimizer.

DEFAULT: {}

learn_one(x, y, **kwargs)

Performs one step of training with the sliding window of the most recent examples.

PARAMETER DESCRIPTION
x

Input example.

TYPE: dict

y

Target value.

TYPE: RegTarget

RETURNS DESCRIPTION
RollingRegressor

The regressor itself.
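The flow of a single learn_one call can be sketched as: append x to the window, run one training step on the whole window with target y, and return self. This is a sketch under stated assumptions. RollingLearnerSketch and _train_step are hypothetical names, and the sketch assumes a training step runs on every call, even before the window is full.

```python
from collections import deque


class RollingLearnerSketch:
    """Hypothetical stand-in that records what each training step would see."""

    def __init__(self, window_size: int = 3):
        self._window = deque(maxlen=window_size)
        self.steps = []  # (window rows, target) pairs passed to the training step

    def _train_step(self, rows: list, y: float) -> None:
        # A real wrapper would run forward, loss, backward and optimizer.step() here.
        self.steps.append((rows, y))

    def learn_one(self, x: dict, y: float):
        self._window.append(list(x.values()))
        self._train_step(list(self._window), y)
        return self  # returning self allows call chaining


learner = RollingLearnerSketch(window_size=2)
learner.learn_one({"a": 1.0}, 10.0).learn_one({"a": 2.0}, 20.0).learn_one({"a": 3.0}, 30.0)
for rows, y in learner.steps:
    print(rows, y)
# [[1.0]] 10.0
# [[1.0], [2.0]] 20.0
# [[2.0], [3.0]] 30.0
```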

predict_one(x)

Predicts the target value for x using the current sliding window of the most recent examples.

PARAMETER DESCRIPTION
x

Input example.

TYPE: dict

RETURNS DESCRIPTION
RegTarget

Predicted target value.
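Online regressors with this interface are typically evaluated test-then-train: each example is first passed to predict_one and only afterwards to learn_one, so every prediction is made on unseen data. River also offers evaluate.progressive_val_score for this pattern. The loop below is a minimal sketch that uses a hypothetical stand-in, WindowMeanRegressor, which predicts the mean of feature "a" over the window, in place of a trained RollingRegressor.

```python
from collections import deque


class WindowMeanRegressor:
    """Hypothetical stand-in: predicts the mean of feature 'a' over the window."""

    def __init__(self, window_size: int = 3):
        self._window = deque(maxlen=window_size)

    def learn_one(self, x: dict, y: float):
        self._window.append(x["a"])
        return self

    def predict_one(self, x: dict) -> float:
        # Prediction window = stored window plus x; the stored window is not mutated.
        rows = deque(self._window, maxlen=self._window.maxlen)
        rows.append(x["a"])
        return sum(rows) / len(rows)


stream = [({"a": 1.0}, 1.0), ({"a": 2.0}, 2.0), ({"a": 3.0}, 3.0), ({"a": 4.0}, 4.0)]
model = WindowMeanRegressor(window_size=2)
abs_errors = []
for x, y in stream:
    y_pred = model.predict_one(x)  # test first ...
    abs_errors.append(abs(y - y_pred))
    model.learn_one(x, y)          # ... then train

mae = sum(abs_errors) / len(abs_errors)
print(mae)  # 0.375
```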