layer_adaptation
expand_weights(weights, axis, n_dims_to_add, init_fn, n_subparams=1)
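Expands weights along the given axis by n_dims_to_add entries per sub-parameter. The original weights are split evenly into n_subparams sub-parameters along axis. New weights created with init_fn are appended to each sub-parameter, and the pieces are then recombined. In total, the size of axis grows by n_subparams * n_dims_to_add.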
PARAMETER | DESCRIPTION |
---|---|
weights | Parameter to be expanded. |
axis | Axis along which to expand the parameter. |
n_dims_to_add | Number of dims to add to each sub-parameter within the parameter. |
init_fn | Function used to initialize the new weights. |
n_subparams | Number of sub-parameters contained in the parameter. Defaults to 1. |
RETURNS | DESCRIPTION |
---|---|
weights_expanded | The expanded weights as a PyTorch parameter. |
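The splitting-and-appending behaviour described above can be illustrated with a minimal PyTorch sketch. The name expand_weights_sketch is hypothetical. The sketch assumes that init_fn follows the torch.nn.init convention of filling a tensor in place (for example torch.nn.init.zeros_). The library's own init_fn contract may differ, so treat this as an illustration rather than the actual implementation.

```python
import torch


def expand_weights_sketch(weights, axis, n_dims_to_add, init_fn, n_subparams=1):
    """Hypothetical re-implementation of the splitting-and-appending scheme.

    Assumption: init_fn fills a tensor in place, like the torch.nn.init functions.
    """
    size = weights.shape[axis]
    if size % n_subparams != 0:
        raise ValueError("axis size must be divisible by n_subparams")
    # Split the parameter evenly into its sub-parameters along `axis`.
    sub_params = torch.split(weights.detach(), size // n_subparams, dim=axis)

    expanded = []
    for sub in sub_params:
        # The new block matches `sub` everywhere except along `axis`.
        new_shape = list(sub.shape)
        new_shape[axis] = n_dims_to_add
        new_block = torch.empty(new_shape, dtype=sub.dtype, device=sub.device)
        init_fn(new_block)  # assumed in-place initializer
        # Append the new entries after this sub-parameter's existing ones.
        expanded.append(torch.cat([sub, new_block], dim=axis))

    # Recombine the sub-parameters and wrap the result as a trainable parameter.
    return torch.nn.Parameter(torch.cat(expanded, dim=axis))


# Example: a (6, 4) weight holding 3 stacked sub-parameters of 2 rows each.
# Adding 1 row per sub-parameter yields shape (9, 4).
w = torch.randn(6, 4)
w_new = expand_weights_sketch(w, axis=0, n_dims_to_add=1,
                              init_fn=torch.nn.init.zeros_, n_subparams=3)
print(w_new.shape)  # torch.Size([9, 4])
```

Splitting before appending matters for parameters that stack several blocks along one axis. One example is the gate weights of torch.nn.GRU, where weight_ih_l0 has shape (3*hidden_size, input_size). Because the weight is split first, each gate receives its own new rows. Otherwise, all new rows would land after the last gate.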
get_expansion_instructions(param_shapes)
Returns a dictionary that describes how the axes of each layer parameter listed in param_shapes correspond to the layer's input and output dimensionality. This correspondence is determined by each parameter's shape string.
PARAMETER | DESCRIPTION |
---|---|
param_shapes | Dictionary containing all parameters of a layer as keys and their corresponding shape strings as values. |
RETURNS | DESCRIPTION |
---|---|
instructions | Dictionary specifying, for each parameter, which axes must be altered to modify the input or output dimensionality, as well as the number of sub-parameters contained in each of those axes. |
get_in_out_axes(shape_str)
Returns a dictionary that describes how the axis sizes of a single parameter correspond to the input and output dimensionality, as determined by its shape string.
PARAMETER | DESCRIPTION |
---|---|
shape_str | String specifying the shape of a parameter. |
RETURNS | DESCRIPTION |
---|---|
axes | Dictionary specifying which axes must be altered to modify the input or output dimensionality, as well as the number of sub-parameters contained in each of those axes. |
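The two shape-string helpers, get_in_out_axes and get_expansion_instructions, can be illustrated together with a small pure-Python sketch. The shape-string format is not documented here, so the sketch assumes a hypothetical comma-separated format with one token per axis:

- "in" or "out" marks an axis that scales with the input or output dimensionality.
- An optional "k*" prefix (as in "3*out") marks k stacked sub-parameters on that axis.
- Any other token is a fixed axis that is left alone.

The function names and the layout of the returned dictionaries are also illustrative, not the library's API.

```python
def get_in_out_axes_sketch(shape_str):
    """Parse a hypothetical shape string such as "3*out,in" (assumed format)."""
    axes = {"in": [], "out": []}
    for axis, token in enumerate(shape_str.split(",")):
        n_subparams = 1
        if "*" in token:
            # Assumed "k*" prefix: this axis holds k stacked sub-parameters.
            factor, token = token.split("*", 1)
            n_subparams = int(factor)
        token = token.strip()
        # Only axes tied to the input or output size are recorded.
        if token in axes:
            axes[token].append({"axis": axis, "n_subparams": n_subparams})
    return axes


def get_expansion_instructions_sketch(param_shapes):
    """Apply the per-parameter parser to every parameter of a layer."""
    return {name: get_in_out_axes_sketch(shape) for name, shape in param_shapes.items()}


# Hypothetical shape strings for a single-layer torch.nn.GRU, whose
# weight_ih_l0 is (3*hidden_size, input_size), weight_hh_l0 is
# (3*hidden_size, hidden_size) and bias_ih_l0 is (3*hidden_size,).
gru_shapes = {
    "weight_ih_l0": "3*out,in",
    "weight_hh_l0": "3*out,out",
    "bias_ih_l0": "3*out",
}
instructions = get_expansion_instructions_sketch(gru_shapes)
print(instructions["weight_hh_l0"])
# {'in': [], 'out': [{'axis': 0, 'n_subparams': 3}, {'axis': 1, 'n_subparams': 1}]}
```

Each recorded axis and its n_subparams count line up with the axis and n_subparams arguments of expand_weights. This is presumably how the parts of this module fit together: widening a layer's output means expanding every "out" axis of every parameter.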