multioutput

Classes:

MultiTargetRegressor
    Incremental multi-target regression wrapper for PyTorch modules.

MultiTargetRegressor

MultiTargetRegressor(
    module: Module,
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    is_feature_incremental: bool = False,
    is_target_incremental: bool = False,
    lr: float = 0.001,
    device: str = "cpu",
    seed: int = 42,
    **kwargs
)

Bases: MultiTargetRegressor, DeepEstimator

Incremental multi-target regression wrapper for PyTorch modules.

This estimator adapts a torch.nn.Module to the river streaming API for multi-target (a.k.a. multi-output) regression. It optionally supports feature-incremental learning (dynamic growth of the input layer when new feature names appear), as provided by deep_river.base.DeepEstimator, and additionally, optionally, target-incremental learning: if new target names appear during the stream, the output layer can be expanded on the fly so the model natively handles the enlarged target vector.

Targets are tracked via an ordered sortedcontainers.SortedSet to guarantee deterministic ordering between training and prediction. Incoming target dictionaries or frames are converted into dense tensors with columns arranged according to the observed target name order. Missing targets (when the model has been expanded but a prior sample omits some target) are imputed with 0.0.
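
For illustration, a minimal sketch of the ordering-and-imputation behaviour described above (this mirrors the description, not the library's internal code):

>>> from sortedcontainers import SortedSet
>>> observed = SortedSet(['y2', 'y1'])  # insertion order is irrelevant
>>> list(observed)                      # deterministic alphabetical order
['y1', 'y2']
>>> y = {'y2': 2.0}                     # 'y1' missing for this sample
>>> [float(y.get(name, 0.0)) for name in observed]
[0.0, 2.0]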

Parameters:

module : Module
    PyTorch module producing an output tensor of shape (N, T) where T is the
    current number of target variables.
loss_fn : str | Callable, default='mse'
    Loss identifier or custom callable passed through
    deep_river.utils.get_loss_fn.
optimizer_fn : str | Callable, default='sgd'
    Optimizer identifier (e.g. 'adam', 'sgd') or factory / class.
is_feature_incremental : bool, default=False
    If True, unseen feature names trigger expansion of the first trainable
    layer (see DeepEstimator).
is_target_incremental : bool, default=False
    If True, unseen target names trigger expansion of the last trainable
    layer. Expansion preserves existing weights and initialises new units
    with small random values.
lr : float, default=1e-3
    Learning rate.
device : str, default='cpu'
    Torch device (e.g. 'cuda').
seed : int, default=42
    Random seed for reproducibility.
**kwargs
    Extra arguments stored for persistence / cloning.

Examples:

>>> import torch
>>> from torch import nn
>>> from deep_river.regression.multioutput import MultiTargetRegressor
>>> class TinyMultiNet(nn.Module):
...     def __init__(self, n_features, n_outputs):
...         super().__init__()
...         self.net = nn.Sequential(
...             nn.Linear(n_features, 8),
...             nn.ReLU(),
...             nn.Linear(8, n_outputs)
...         )
...     def forward(self, x):
...         return self.net(x)
>>> model = MultiTargetRegressor(
...     module=TinyMultiNet(3, 2),
...     loss_fn='mse',
...     optimizer_fn='sgd',
...     is_feature_incremental=True,
...     is_target_incremental=True,
... )
>>> x = {'a': 1.0, 'b': 2.0, 'c': 3.0}
>>> y = {'y1': 10.0, 'y2': 20.0}
>>> _ = model.learn_one(x, y)
>>> model.predict_one(x)
{'y1': ..., 'y2': ...}

Notes

  • The module's last trainable leaf layer is treated as the output layer for
    target-incremental expansion.
  • If is_target_incremental is disabled, the number of outputs is fixed and
    encountering a new target name will only register it internally (the
    tensor conversion will still allocate a slot, but the model's output
    layer size will not change, possibly causing a mismatch). Enabling target
    incrementality is therefore recommended for truly open-world streams; a
    sketch follows this list.
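
A hedged illustration of this open-world behaviour, reusing the TinyMultiNet module from the Examples section and assuming the output layer grows as documented:

>>> model = MultiTargetRegressor(
...     module=TinyMultiNet(3, 1),
...     is_target_incremental=True,
... )
>>> x = {'a': 1.0, 'b': 2.0, 'c': 3.0}
>>> _ = model.learn_one(x, {'y1': 10.0})
>>> _ = model.learn_one(x, {'y1': 10.0, 'y2': 20.0})  # 'y2' appears mid-stream
>>> sorted(model.predict_one(x))  # the output layer now covers both targets
['y1', 'y2']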

Methods:

clone
    Return a fresh estimator instance with (optionally) copied state.
draw
    Render a (partial) computational graph of the wrapped model.
learn_many
    Learn from a batch of multi-target instances.
learn_one
    Learn from a single multi-target instance.
load
    Load a previously saved estimator.
predict_many
    Predict target values for multiple instances.
predict_one
    Predict a dictionary of target values for a single instance.
save
    Persist the estimator (architecture, weights, optimiser & runtime state).

Source code in deep_river/regression/multioutput.py
def __init__(
    self,
    module: torch.nn.Module,
    loss_fn: Union[str, Callable] = "mse",
    optimizer_fn: Union[str, Callable] = "sgd",
    is_feature_incremental: bool = False,
    is_target_incremental: bool = False,
    lr: float = 1e-3,
    device: str = "cpu",
    seed: int = 42,
    **kwargs,
):
    super().__init__(
        module=module,
        loss_fn=loss_fn,
        optimizer_fn=optimizer_fn,
        lr=lr,
        device=device,
        seed=seed,
        is_feature_incremental=is_feature_incremental,
        **kwargs,
    )
    self.is_target_incremental = is_target_incremental
    self.observed_targets: SortedSet[FeatureName] = SortedSet()

clone

clone(
    new_params=None,
    include_attributes: bool = False,
    copy_weights: bool = False,
)

Return a fresh estimator instance with (optionally) copied state.

Parameters:

new_params : dict | None, default=None
    Parameter overrides for the cloned instance.
include_attributes : bool, default=False
    If True, runtime state (observed features, buffers) is also copied.
copy_weights : bool, default=False
    If True, model weights are copied (otherwise the module is re-initialised).

Source code in deep_river/base.py
def clone(
    self,
    new_params=None,
    include_attributes: bool = False,
    copy_weights: bool = False,
):
    """Return a fresh estimator instance with (optionally) copied state.

    Parameters
    ----------
    new_params : dict | None
        Parameter overrides for the cloned instance.
    include_attributes : bool, default=False
        If True, runtime state (observed features, buffers) is also copied.
    copy_weights : bool, default=False
        If True, model weights are copied (otherwise the module is re‑initialised).
    """
    new_params = new_params or {}
    copy_weights = new_params.pop("copy_weights", copy_weights)

    params = {**self._get_all_init_params(), **new_params}

    if "module" not in new_params:
        params["module"] = self._rebuild_module()

    new_est = self.__class__(**self._filter_kwargs(self.__class__.__init__, params))

    if copy_weights and hasattr(self.module, "state_dict"):
        new_est.module.load_state_dict(self.module.state_dict())

    if include_attributes:
        new_est._restore_runtime_state(self._get_runtime_state())

    return new_est
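
A brief usage sketch, continuing the model from the Examples section (a hedged illustration based on the documented parameters):

>>> warm = model.clone(copy_weights=True)          # same architecture, copied weights
>>> fresh = model.clone(new_params={'lr': 0.01})   # re-initialised module, new learning rate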

draw

draw()

Render a (partial) computational graph of the wrapped model.

Imports graphviz and torchviz lazily. Raises an informative ImportError if the optional dependencies are not installed.

Source code in deep_river/base.py
def draw(self):  # type: ignore[override]
    """Render a (partial) computational graph of the wrapped model.

    Imports ``graphviz`` and ``torchviz`` lazily. Raises an informative
    ImportError if the optional dependencies are not installed.
    """
    try:  # pragma: no cover
        from torchviz import make_dot  # type: ignore
    except Exception as err:  # noqa: BLE001
        raise ImportError(
            "graphviz and torchviz must be installed to draw the model."
        ) from err

    first_parameter = next(self.module.parameters())
    input_shape = first_parameter.size()
    y_pred = self.module(torch.rand(input_shape))
    return make_dot(y_pred.mean(), params=dict(self.module.named_parameters()))
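
A usage sketch, assuming the optional graphviz and torchviz dependencies are installed (the output filename is illustrative):

>>> dot = model.draw()
>>> _ = dot.render('model_graph')  # graphviz writes model_graph and model_graph.pdf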

learn_many

learn_many(
    X: DataFrame,
    y: Union[
        DataFrame, Series, Mapping[str, Sequence[RegTarget]]
    ],
) -> None

Learn from a batch of multi-target instances.

Parameters:

X : DataFrame
    Feature matrix (rows are samples, columns are feature names).
y : DataFrame | Series | mapping
    Target matrix. Preferred is a DataFrame with one column per target. A
    Series is interpreted as one target. A mapping of name -> list is
    converted into a DataFrame first.

Source code in deep_river/regression/multioutput.py
def learn_many(
    self,
    X: pd.DataFrame,
    y: Union[pd.DataFrame, pd.Series, Mapping[str, Sequence[RegTarget]]],
) -> None:
    """Learn from a batch of multi-target instances.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix (rows are samples, columns are feature names).
    y : pandas.DataFrame | pandas.Series | mapping
        Target matrix. Preferred is a DataFrame with one column per target.
        A Series is interpreted as *one* target. A mapping of ``name -> list``
        is converted into a DataFrame first.
    """
    self._update_observed_features(X)
    y_df = self._coerce_targets_to_frame(y)
    self._update_observed_targets(y_df)

    x_t = self._df2tensor(X)
    y_t = self._multi_target_frame_to_tensor(y_df)
    self._learn(x_t, y_t)
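
A usage sketch with a small batch, continuing the model from the Examples section (feature and target names are illustrative):

>>> import pandas as pd
>>> X = pd.DataFrame({'a': [1.0, 2.0], 'b': [2.0, 3.0], 'c': [3.0, 4.0]})
>>> Y = pd.DataFrame({'y1': [10.0, 11.0], 'y2': [20.0, 21.0]})
>>> model.learn_many(X, Y)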

learn_one

learn_one(
    x: dict, y: dict[FeatureName, RegTarget], **kwargs
) -> None

Learn from a single multi-target instance.

Parameters:

x : dict[str, float]
    Feature mapping.
y : dict[str, float]
    Mapping of target name -> target value.
**kwargs
    Ignored (kept for signature compatibility / future hooks).

Source code in deep_river/regression/multioutput.py
def learn_one(
    self,
    x: dict,
    y: dict[FeatureName, RegTarget],
    **kwargs,
) -> None:
    """Learn from a single multi-target instance.

    Parameters
    ----------
    x : dict[str, float]
        Feature mapping.
    y : dict[str, float]
        Mapping of target name -> target value.
    **kwargs
        Ignored (kept for signature compatibility / future hooks).
    """
    self._update_observed_features(x)
    self._update_observed_targets(y)
    x_t = self._dict2tensor(dict(x))
    y_t = self._single_target_dict_to_tensor(y)
    self._learn(x_t, y_t)

load classmethod

load(filepath: Union[str, Path])

Load a previously saved estimator.

The method reconstructs the estimator class, its wrapped module, optimiser state and runtime information (feature names, buffers, etc.).

Source code in deep_river/base.py
@classmethod
def load(cls, filepath: Union[str, Path]):
    """Load a previously saved estimator.

    The method reconstructs the estimator class, its wrapped module, optimiser
    state and runtime information (feature names, buffers, etc.).
    """
    with open(filepath, "rb") as f:
        state = pickle.load(f)

    estimator_cls = cls._import_from_path(state["estimator_class"])
    init_params = state["init_params"]

    # Rebuild module if needed
    if "module" in init_params and isinstance(init_params["module"], dict):
        module_info = init_params.pop("module")
        module_cls = cls._import_from_path(module_info["class"])
        module = module_cls(
            **cls._filter_kwargs(module_cls.__init__, module_info["kwargs"])
        )
        if state.get("model_state_dict"):
            module.load_state_dict(state["model_state_dict"])
        init_params["module"] = module

    estimator = estimator_cls(
        **cls._filter_kwargs(estimator_cls.__init__, init_params)
    )

    if state.get("optimizer_state_dict") and hasattr(estimator, "optimizer"):
        try:
            estimator.optimizer.load_state_dict(
                state["optimizer_state_dict"]  # type: ignore[arg-type]
            )
        except Exception:  # noqa: E722
            pass

    estimator._restore_runtime_state(state.get("runtime_state", {}))
    return estimator

predict_many

predict_many(X: DataFrame) -> DataFrame

Predict target values for multiple instances.

Returns:

DataFrame
    DataFrame whose columns follow the ordering of observed_targets.

Source code in deep_river/regression/multioutput.py
def predict_many(self, X: pd.DataFrame) -> pd.DataFrame:
    """Predict target values for multiple instances.

    Returns
    -------
    pandas.DataFrame
        DataFrame whose columns follow the ordering of ``observed_targets``.
    """
    self._update_observed_features(X)
    x_t = self._df2tensor(X)
    self.module.eval()
    with torch.inference_mode():
        y_pred = self.module(x_t)
        if y_pred.is_cuda:
            y_pred = y_pred.cpu()
    # Ensure 2D
    if y_pred.dim() == 1:
        y_pred = y_pred.view(-1, 1)
    cols = list(self.observed_targets)
    # Truncate or pad columns if dimensions drift (defensive)
    if y_pred.shape[1] < len(cols):
        pad = torch.zeros(
            (y_pred.shape[0], len(cols) - y_pred.shape[1]),
            dtype=y_pred.dtype,
        )
        y_pred = torch.cat([y_pred, pad], dim=1)
    elif y_pred.shape[1] > len(cols):
        extra = [f"__extra_{i}" for i in range(y_pred.shape[1] - len(cols))]
        cols = cols + extra
    return pd.DataFrame(y_pred.numpy(), columns=cols)
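
A usage sketch, reusing the batch X from the learn_many example above; under that assumption the model has observed the targets 'y1' and 'y2':

>>> preds = model.predict_many(X)
>>> list(preds.columns)  # columns follow the observed_targets ordering
['y1', 'y2']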

predict_one

predict_one(x: dict) -> dict[FeatureName, RegTarget]

Predict a dictionary of target values for a single instance.

Source code in deep_river/regression/multioutput.py
def predict_one(self, x: dict) -> dict[FeatureName, RegTarget]:
    """Predict a dictionary of target values for a single instance."""
    self._update_observed_features(x)
    x_t = self._dict2tensor(dict(x))
    self.module.eval()
    with torch.inference_mode():
        y_pred_t = self.module(x_t).squeeze(0)
        if y_pred_t.dim() == 0:  # single value fallback
            y_pred_t = y_pred_t.view(1)
        if y_pred_t.is_cuda:
            y_pred_t = y_pred_t.cpu()
        y_list: list[float] = [float(v) for v in y_pred_t.tolist()]
    return {
        cast(FeatureName, t): cast(
            RegTarget, (y_list[i] if i < len(y_list) else float("nan"))
        )
        for i, t in enumerate(self.observed_targets)
    }

save

save(filepath: Union[str, Path]) -> None

Persist the estimator (architecture, weights, optimiser & runtime state).

Parameters:

filepath : str | Path
    Destination file. Parent directories are created automatically.

Source code in deep_river/base.py
def save(self, filepath: Union[str, Path]) -> None:
    """Persist the estimator (architecture, weights, optimiser & runtime state).

    Parameters
    ----------
    filepath : str | Path
        Destination file. Parent directories are created automatically.
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    state = {
        "estimator_class": f"{type(self).__module__}.{type(self).__name__}",
        "init_params": self._get_all_init_params(),
        "model_state_dict": getattr(self.module, "state_dict", lambda: {})(),
        "optimizer_state_dict": getattr(self.optimizer, "state_dict", lambda: {})(),
        "runtime_state": self._get_runtime_state(),
    }

    with open(filepath, "wb") as f:
        pickle.dump(state, f)
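
A hedged round-trip sketch, continuing the model and instance x from the Examples section (the filename is illustrative):

>>> model.save('multi_target_model.pkl')
>>> restored = MultiTargetRegressor.load('multi_target_model.pkl')
>>> restored.predict_one(x)
{'y1': ..., 'y2': ...}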