Skip to content

params

Functions:

Name Description
get_activation_fn

Returns the requested activation function as a nn.Module class.

get_init_fn

Returns the requested init function.

get_loss_fn

Returns the requested loss function as a function.

get_optim_fn

Returns the requested optimizer as a torch.optim.Optimizer class.

get_activation_fn

get_activation_fn(
    activation_fn: Union[str, Callable],
) -> Callable

Returns the requested activation function as a nn.Module class.

Parameters:

Name Type Description Default
activation_fn Union[str, Callable]

The activation function to fetch. Can be a string or a nn.Module class.

required

Returns:

Type Description
Callable

The class of the requested activation function.

Source code in deep_river/utils/params.py
def get_activation_fn(activation_fn: Union[str, Callable]) -> Callable:
    """Returns the requested activation function as a nn.Module class.

    Parameters
    ----------
    activation_fn
        The activation function to fetch. Can be a string key into
        ``ACTIVATION_FNS`` or a ``nn.Module`` class.

    Returns
    -------
    Callable
        The class of the requested activation function.

    Raises
    ------
    ValueError
        If the string is not a known activation name, or the callable
        cannot be instantiated into a ``nn.Module``.
    """

    def _invalid() -> ValueError:
        # Built lazily so a valid argument never pays the formatting cost.
        return ValueError(
            BASE_PARAM_ERROR.format("activation function", activation_fn, "nn.Module")
        )

    if isinstance(activation_fn, str):
        try:
            return ACTIVATION_FNS[activation_fn]
        except KeyError:
            raise _invalid() from None
    # Validate a non-string argument by instantiating it. A plain function
    # (not a zero-arg-constructible class) raises TypeError here; previously
    # that TypeError leaked to the caller instead of the documented ValueError.
    try:
        instance = activation_fn()
    except TypeError:
        raise _invalid() from None
    if not isinstance(instance, nn.Module):
        raise _invalid()
    return activation_fn

get_init_fn

get_init_fn(init_fn)

Returns the requested init function.

Parameters:

Name Type Description Default
init_fn

The init function to fetch. Must be one of ["xavier_uniform", "uniform", "kaiming_uniform"].

required

Returns:

Type Description
Callable

A callable applying the requested init function to a weight tensor.

Source code in deep_river/utils/params.py
def get_init_fn(init_fn):
    """Returns the requested init function.

    Parameters
    ----------
    init_fn : str
        The init function to fetch. Must be one of ["xavier_uniform",
        "uniform", "kaiming_uniform"]; unrecognized names fall back to
        xavier-uniform initialization.

    Returns
    -------
    Callable
        A function ``(weight, activation_fn)`` that applies the requested
        initialization to ``weight``.
    """
    # Fall back to the xavier-uniform *function*, not the string
    # "xavier_uniform": the previous string default made the generic branch
    # below crash with ``TypeError: 'str' object is not callable`` whenever
    # an unrecognized name was passed.
    init_fn_ = INIT_FNS.get(init_fn, INIT_FNS["xavier_uniform"])
    if init_fn.startswith("xavier"):
        # Xavier variants scale by the gain of the downstream activation.
        def result(weight, activation_fn):
            return init_fn_(weight, gain=nn.init.calculate_gain(activation_fn))

    elif init_fn.startswith("kaiming"):
        # Kaiming variants take the activation name directly.
        def result(weight, activation_fn):
            return init_fn_(weight, nonlinearity=activation_fn)

    elif init_fn == "uniform":

        def result(weight, activation_fn):
            # NOTE(review): this branch is a no-op — it returns 0 without
            # touching ``weight``. Looks unintentional; confirm whether
            # ``init_fn_(weight)`` was meant here. Preserved as-is.
            return 0

    else:
        # Generic path: apply the init function with no extra arguments.
        def result(weight, activation_fn):
            return init_fn_(weight)

    return result

get_loss_fn

get_loss_fn(loss_fn: Union[str, Callable]) -> Callable

Returns the requested loss function as a function.

Parameters:

Name Type Description Default
loss_fn Union[str, Callable]

The loss function to fetch. Can be a string or a function.

required

Returns:

Type Description
Callable

The function of the requested loss function.

Source code in deep_river/utils/params.py
def get_loss_fn(loss_fn: Union[str, Callable]) -> Callable:
    """Returns the requested loss function as a function.

    Parameters
    ----------
    loss_fn
        The loss function to fetch. Can be a string or a function.

    Returns
    -------
    Callable
        The function of the requested loss function.
    """
    err = ValueError(BASE_PARAM_ERROR.format("loss function", loss_fn, "function"))
    # Non-string arguments are accepted as-is, provided they are callable.
    if not isinstance(loss_fn, str):
        if not callable(loss_fn):
            raise err
        return loss_fn
    # String arguments are resolved through the registry of known losses.
    try:
        return LOSS_FNS[loss_fn]
    except KeyError:
        raise err

get_optim_fn

get_optim_fn(optim_fn: Union[str, Callable]) -> Callable

Returns the requested optimizer as a torch.optim.Optimizer class.

Parameters:

Name Type Description Default
optim_fn Union[str, Callable]

The optimizer to fetch. Can be a string or a torch.optim.Optimizer class.

required

Returns:

Type Description
Callable

The class of the requested optimizer.

Source code in deep_river/utils/params.py
def get_optim_fn(optim_fn: Union[str, Callable]) -> Callable:
    """Returns the requested optimizer as a ``torch.optim.Optimizer`` class.

    Parameters
    ----------
    optim_fn
        The optimizer to fetch. Can be a string key into ``OPTIMIZER_FNS``
        or a ``torch.optim.Optimizer`` class.

    Returns
    -------
    Callable
        The class of the requested optimizer.

    Raises
    ------
    ValueError
        If the string is not a known optimizer name, or the callable
        does not construct a ``torch.optim.Optimizer`` instance.
    """

    def _invalid() -> ValueError:
        # Built lazily so valid arguments skip message formatting.
        return ValueError(BASE_PARAM_ERROR.format("optimizer", optim_fn, "nn.Module"))

    if isinstance(optim_fn, str):
        try:
            return OPTIMIZER_FNS[optim_fn]
        except KeyError:
            raise _invalid() from None
    # Validate a non-string argument by constructing a throwaway optimizer
    # on a dummy parameter. A callable that is not an optimizer class raises
    # TypeError here; previously that TypeError leaked to the caller instead
    # of the documented ValueError.
    try:
        instance = optim_fn(params=[torch.empty(1)], lr=1e-3)
    except TypeError:
        raise _invalid() from None
    if not isinstance(instance, torch.optim.Optimizer):
        raise _invalid()
    return optim_fn