
regressor

Regressor(module, loss_fn='mse', optimizer_fn='sgd', lr=0.001, device='cpu', seed=42, **kwargs)

Bases: DeepEstimator, MiniBatchRegressor

Wrapper for PyTorch regression models that enables compatibility with River.

PARAMETER DESCRIPTION
module

Torch Module class that builds the regression model to be wrapped. The Module should accept a parameter n_features, so that the model's input shape can be determined from the number of features in the first training example.

TYPE: Type[Module]
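The n_features requirement exists because the input size is unknown until the first example arrives, so the wrapper receives a class and builds the model later. The sketch below illustrates that lazy-construction idea in plain Python. TinyLinearModule and LazyWrapper are hypothetical names standing in for a torch.nn.Module subclass and the wrapper; they are not part of the library.

```python
import random
from typing import Dict, List


class TinyLinearModule:
    """Hypothetical stand-in for a torch.nn.Module whose constructor takes n_features."""

    def __init__(self, n_features: int):
        self.n_features = n_features
        # One weight per input feature plus a bias term.
        self.weights: List[float] = [random.uniform(-0.1, 0.1) for _ in range(n_features)]
        self.bias = 0.0

    def forward(self, x: List[float]) -> float:
        return sum(w * xi for w, xi in zip(self.weights, x)) + self.bias


class LazyWrapper:
    """Hypothetical wrapper: stores the module class and builds it on the first example."""

    def __init__(self, module):
        self.module_cls = module
        self.module = None

    def ensure_built(self, x: Dict[str, float]) -> None:
        if self.module is None:
            # The input size is only known once the first example arrives.
            self.module = self.module_cls(n_features=len(x))


wrapper = LazyWrapper(module=TinyLinearModule)
wrapper.ensure_built({"a": 1.0, "b": 2.0, "c": 3.0})
print(wrapper.module.n_features)  # 3
```

Later calls to ensure_built leave the already-built module untouched, so its input size stays fixed.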

loss_fn

Loss function to be used for training the wrapped model. Can be a loss function provided by torch.nn.functional or one of the following: 'mse', 'l1', 'cross_entropy', 'binary_crossentropy', 'smooth_l1', 'kl_div'.

TYPE: Union[str, Callable] DEFAULT: 'mse'
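Because loss_fn accepts either a name or a callable, the wrapper has to resolve names to functions and pass callables through. A minimal sketch of that resolution, assuming a simple name-to-function lookup. resolve_loss and LOSSES are hypothetical, and the scalar losses only mimic the formulas of their torch.nn.functional counterparts for a single prediction.

```python
from typing import Callable, Dict, Union


def mse(y_pred: float, y_true: float) -> float:
    return (y_pred - y_true) ** 2


def l1(y_pred: float, y_true: float) -> float:
    return abs(y_pred - y_true)


def smooth_l1(y_pred: float, y_true: float, beta: float = 1.0) -> float:
    # Quadratic near zero, linear in the tails (same formula as PyTorch's smooth_l1_loss).
    diff = abs(y_pred - y_true)
    return 0.5 * diff ** 2 / beta if diff < beta else diff - 0.5 * beta


# Hypothetical registry used only for illustration.
LOSSES: Dict[str, Callable[[float, float], float]] = {
    "mse": mse,
    "l1": l1,
    "smooth_l1": smooth_l1,
}


def resolve_loss(loss_fn: Union[str, Callable]) -> Callable:
    if callable(loss_fn):
        return loss_fn  # user-supplied callables are used unchanged
    try:
        return LOSSES[loss_fn]
    except KeyError:
        raise ValueError(f"Unknown loss name: {loss_fn!r}") from None
```

For example, resolve_loss("mse")(3.0, 1.0) gives 4.0, while an unknown name raises ValueError instead of failing later during training.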

optimizer_fn

Optimizer to be used for training the wrapped model. Can be an optimizer class provided by torch.optim or one of the following: "adam", "adam_w", "sgd", "rmsprop", "lbfgs".

TYPE: Union[str, Callable] DEFAULT: 'sgd'

lr

Learning rate of the optimizer.

TYPE: float DEFAULT: 0.001
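With the default optimizer_fn='sgd', the learning rate scales every parameter update. Plain SGD without momentum or weight decay, which is what torch.optim.SGD does under its default settings, applies w ← w − lr · ∂L/∂w. Below is a worked single step for a one-weight model under squared error, in plain Python. sgd_step is a hypothetical helper.

```python
from typing import List


def sgd_step(weights: List[float], grads: List[float], lr: float = 0.001) -> List[float]:
    """Plain SGD update applied to each parameter: w <- w - lr * grad."""
    return [w - lr * g for w, g in zip(weights, grads)]


# One-weight model y_pred = w * x with loss (y_pred - y)^2, so dL/dw = 2 * (w*x - y) * x.
w, x, y = 0.5, 2.0, 3.0
grad = 2 * (w * x - y) * x          # 2 * (1.0 - 3.0) * 2.0 = -8.0
new_w = sgd_step([w], [grad])[0]    # 0.5 - 0.001 * (-8.0), approximately 0.508
```

A negative gradient moves the weight upward, and the small default lr keeps each step short.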

device

Device to run the wrapped model on. Can be "cpu" or "cuda".

TYPE: str DEFAULT: 'cpu'

seed

Random seed to be used for training the wrapped model.

TYPE: int DEFAULT: 42
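Fixing the seed makes weight initialization, and with it training, repeatable across runs. The sketch shows the idea with Python's random module; in PyTorch the analogous call is torch.manual_seed. Exactly where the wrapper applies the seed is not described here, so init_weights is purely illustrative and hypothetical.

```python
import random
from typing import List


def init_weights(n_features: int, seed: int = 42) -> List[float]:
    # A dedicated generator seeded once makes the initial weights reproducible.
    rng = random.Random(seed)
    return [rng.uniform(-0.1, 0.1) for _ in range(n_features)]


assert init_weights(4) == init_weights(4)          # same seed, same weights
assert init_weights(4) != init_weights(4, seed=7)  # different seed, different weights
```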

**kwargs

Additional keyword arguments to be passed to the Module or the optimizer.

DEFAULT: {}
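Since **kwargs may carry arguments for either the Module or the optimizer, one plausible way to route them is to inspect each constructor's signature and forward only the names it accepts. This is an assumption about the mechanism, not the library's confirmed implementation; TinyModule, TinySGD, and filter_kwargs are hypothetical.

```python
import inspect
from typing import Any, Callable, Dict


class TinyModule:
    def __init__(self, n_features: int, hidden: int = 8):
        self.n_features, self.hidden = n_features, hidden


class TinySGD:
    def __init__(self, params, lr: float = 0.001, momentum: float = 0.0):
        self.params, self.lr, self.momentum = params, lr, momentum


def filter_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that fn's signature accepts."""
    accepted = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}


kwargs = {"hidden": 16, "momentum": 0.9}
module = TinyModule(n_features=3, **filter_kwargs(TinyModule, kwargs))
optimizer = TinySGD(params=[], lr=0.01, **filter_kwargs(TinySGD, kwargs))
```

Here hidden reaches only the module and momentum reaches only the optimizer.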


learn_many(X, y)

Performs one step of training with a batch of examples.

PARAMETER DESCRIPTION
X

Input examples.

TYPE: DataFrame

y

Target values.

TYPE: Series

RETURNS DESCRIPTION
Regressor

The regressor itself.
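One training step on a batch computes the loss over all rows and applies a single optimizer update. A plain-Python sketch for a linear model with mean squared error, assuming a list of dicts and a list of floats stand in for the pandas objects; learn_many_step is hypothetical.

```python
from typing import Dict, List, Tuple


def learn_many_step(
    weights: List[float],
    bias: float,
    X: List[Dict[str, float]],
    y: List[float],
    lr: float = 0.001,
) -> Tuple[List[float], float]:
    """One SGD step on the mean squared error of a batch for a linear model."""
    features = sorted(X[0])  # fixed column order for every row
    n = len(X)
    grad_w = [0.0] * len(features)
    grad_b = 0.0
    for row, target in zip(X, y):
        x = [row[f] for f in features]
        err = sum(w * xi for w, xi in zip(weights, x)) + bias - target
        # Accumulate the gradient of the batch-averaged squared error.
        for j, xi in enumerate(x):
            grad_w[j] += 2 * err * xi / n
        grad_b += 2 * err / n
    new_weights = [w - lr * g for w, g in zip(weights, grad_w)]
    return new_weights, bias - lr * grad_b
```

Averaging the gradient keeps the step size roughly independent of the batch size.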

learn_one(x, y, **kwargs)

Performs one step of training with a single example.

PARAMETER DESCRIPTION
x

Input example.

TYPE: dict

y

Target value.

TYPE: RegTarget

RETURNS DESCRIPTION
Regressor

The regressor itself.
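learn_one behaves like a training step with a batch of size one. The sketch walks through that cycle for a linear model under squared error: turn the feature dict into a vector, predict, take the gradient, and update. OnlineLinearRegressor is a hypothetical pure-Python analogue, and sorting feature names to fix their order is an assumption.

```python
from typing import Dict, List, Optional


class OnlineLinearRegressor:
    """Hypothetical pure-Python analogue of learn_one/predict_one for a linear model."""

    def __init__(self, lr: float = 0.001):
        self.lr = lr
        self.features: Optional[List[str]] = None
        self.weights: List[float] = []
        self.bias = 0.0

    def _vector(self, x: Dict[str, float]) -> List[float]:
        if self.features is None:
            self.features = sorted(x)  # fix the feature order on the first example
            self.weights = [0.0] * len(self.features)
        return [float(x.get(f, 0.0)) for f in self.features]

    def predict_one(self, x: Dict[str, float]) -> float:
        v = self._vector(x)
        return sum(w * xi for w, xi in zip(self.weights, v)) + self.bias

    def learn_one(self, x: Dict[str, float], y: float) -> "OnlineLinearRegressor":
        v = self._vector(x)
        err = self.predict_one(x) - y  # gradient of (y_pred - y)^2 is 2 * err * input
        self.weights = [w - self.lr * 2 * err * xi for w, xi in zip(self.weights, v)]
        self.bias -= self.lr * 2 * err
        return self  # returning the regressor itself allows chaining
```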

predict_many(X)

Predicts the target value for a batch of examples.

PARAMETER DESCRIPTION
X

Input examples.

TYPE: DataFrame

RETURNS DESCRIPTION
List

Predicted target values.
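Predicting for a batch yields one value per row, in row order, and should agree with scoring each row on its own. A sketch for a linear model, assuming a list of dicts stands in for the input batch; predict_many_linear is hypothetical.

```python
from typing import Dict, List


def predict_many_linear(
    weights: Dict[str, float], bias: float, X: List[Dict[str, float]]
) -> List[float]:
    """Score every row of a batch; missing features count as 0.0 (assumption)."""
    return [sum(weights[f] * row.get(f, 0.0) for f in weights) + bias for row in X]


preds = predict_many_linear({"a": 2.0, "b": -1.0}, 0.5, [{"a": 1.0, "b": 1.0}, {"a": 0.0, "b": 2.0}])
```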

predict_one(x)

Predicts the target value for a single example.

PARAMETER DESCRIPTION
x

Input example.

TYPE: dict

RETURNS DESCRIPTION
RegTarget

Predicted target value.
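Because the wrapped model has a fixed input size, each input dict must be mapped to a vector in a consistent feature order. One plausible policy, assumed here rather than taken from the library, is to freeze the order from the first example, fill missing features with 0.0, and ignore features never seen before. FrozenFeatureOrder is hypothetical.

```python
from typing import Dict, List, Optional


class FrozenFeatureOrder:
    """Hypothetical helper: fixes the feature order on the first example it sees."""

    def __init__(self):
        self.names: Optional[List[str]] = None

    def to_vector(self, x: Dict[str, float], fill: float = 0.0) -> List[float]:
        if self.names is None:
            self.names = list(x)  # first-seen order becomes the canonical order
        # Missing features are filled and unseen extra features are dropped (assumption).
        return [float(x.get(name, fill)) for name in self.names]


order = FrozenFeatureOrder()
first = order.to_vector({"temp": 21.5, "humidity": 0.4})   # [21.5, 0.4]
second = order.to_vector({"humidity": 0.5, "wind": 3.0})   # [0.0, 0.5]
```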