
layer_adaptation

expand_weights(weights, axis, n_dims_to_add, init_fn, n_subparams=1)

Expands weights along the given axis by n_dims_to_add per sub-parameter. The original weights are split evenly along axis into n_subparams sub-parameters, and n_dims_to_add newly initialized entries are appended to each of them.

PARAMETER DESCRIPTION
weights

Parameter to be expanded.

TYPE: Tensor

axis

Axis along which to expand the parameter.

TYPE: int

n_dims_to_add

Number of entries to add along axis to each sub-parameter within the parameter, so the axis grows by n_dims_to_add × n_subparams in total.

TYPE: int

init_fn

Function used to initialize the new weights.

TYPE: Callable

n_subparams

Number of sub-parameters contained in the parameter.

TYPE: int DEFAULT: 1

RETURNS DESCRIPTION
weights_expanded

The expanded weights as a PyTorch parameter.

get_expansion_instructions(param_shapes)

Returns a dictionary describing, for each parameter listed in param_shapes, how its axes correspond to the layer's input and output dimensionality, as derived from the parameter's shape string.

PARAMETER DESCRIPTION
param_shapes

Dictionary containing all parameters of a layer as keys and their corresponding shape strings as values.

TYPE: Dict

RETURNS DESCRIPTION
instructions

Dictionary specifying, for each parameter, which axes must be changed to modify the input or output dimensionality, and how many sub-parameters each of those axes contains.
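The shape-string format and the exact structure of the returned dictionary are not documented in this section. A minimal sketch, assuming a hypothetical grammar in which a string such as "(3*out, in)" labels each axis as in or out with an optional "k*" prefix for k stacked sub-parameters, and assuming each parameter maps to "in"/"out" lists of (axis, n_subparams) pairs. The helper _parse_shape is hypothetical.

```python
import re
from typing import Dict, List, Tuple

AxisList = List[Tuple[int, int]]  # (axis, n_subparams) pairs


def _parse_shape(shape_str: str) -> Dict[str, AxisList]:
    # Hypothetical grammar: "(3*out, in)" -> axis 0 is 3 stacked "out" blocks, axis 1 is "in".
    axes: Dict[str, AxisList] = {"in": [], "out": []}
    entries = [e.strip() for e in shape_str.strip("() ").split(",") if e.strip()]
    for axis, entry in enumerate(entries):
        match = re.fullmatch(r"(?:(\d+)\*)?(in|out)", entry)
        if match is None:
            raise ValueError(f"unrecognized axis entry {entry!r}")
        axes[match.group(2)].append((axis, int(match.group(1) or 1)))
    return axes


def get_expansion_instructions(param_shapes: Dict[str, str]) -> Dict[str, Dict[str, AxisList]]:
    # One entry per parameter: the axes that track the input and output dimensionality.
    return {name: _parse_shape(shape) for name, shape in param_shapes.items()}


# A linear-style layer under the assumed notation.
print(get_expansion_instructions({"weight": "(out, in)", "bias": "(out,)"}))
# {'weight': {'in': [(1, 1)], 'out': [(0, 1)]}, 'bias': {'in': [], 'out': [(0, 1)]}}
```

Under this assumed structure, each (axis, n_subparams) pair lines up with the axis and n_subparams arguments of expand_weights.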

get_in_out_axes(shape_str)

Returns a dictionary describing how the axes of a single parameter correspond to the layer's input and output dimensionality, as derived from the parameter's shape string.

PARAMETER DESCRIPTION
shape_str

String specifying the shape of a parameter.

TYPE: str

RETURNS DESCRIPTION
axes

Dictionary specifying which axes of the parameter must be changed to modify the input or output dimensionality, and how many sub-parameters each of those axes contains.
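A minimal sketch of parsing a single shape string, assuming a hypothetical notation (not confirmed by this page) in which each comma-separated entry names an axis as in or out, optionally prefixed by "k*" to mark k stacked sub-parameters. The returned layout, "in"/"out" mapped to lists of (axis, n_subparams) pairs, is likewise an assumption.

```python
import re
from typing import Dict, List, Tuple

# Hypothetical grammar for one axis entry: an optional "<k>*" prefix
# (k stacked sub-parameters) followed by "in" or "out".
_AXIS_ENTRY = re.compile(r"(?:(\d+)\*)?(in|out)")


def get_in_out_axes(shape_str: str) -> Dict[str, List[Tuple[int, int]]]:
    """Map "in" and "out" to the (axis, n_subparams) pairs that scale with them."""
    axes: Dict[str, List[Tuple[int, int]]] = {"in": [], "out": []}
    entries = [e.strip() for e in shape_str.strip("() ").split(",") if e.strip()]
    for axis, entry in enumerate(entries):
        match = _AXIS_ENTRY.fullmatch(entry)
        if match is None:
            raise ValueError(f"unrecognized axis entry {entry!r} in {shape_str!r}")
        n_subparams = int(match.group(1) or 1)  # no prefix means a single sub-parameter
        axes[match.group(2)].append((axis, n_subparams))
    return axes


print(get_in_out_axes("(3*out, in)"))  # {'in': [(1, 1)], 'out': [(0, 3)]}
```

Returning lists rather than a single axis per key allows a parameter whose several axes all scale with the same dimensionality, such as a square "(out, out)" matrix.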