fluke.utils.model
¶
This submodule provides utilities for PyTorch model manipulation.
Classes included in fluke.utils.model
MMMixin: Mixin class for model interpolation.
LinesLinear: Linear layer with global and local weights.
LinesConv2d: Conv2d layer with global and local weights.
LinesLSTM: LSTM layer with global and local weights.
LinesEmbedding: Embedding layer with global and local weights.
LinesBN2d: BatchNorm2d layer with global and local weights.
AllLayerOutputModel: Wrapper class to get the output of all layers in a model.
Functions included in fluke.utils.model
diff_model: Compute the difference between two model state dictionaries.
merge_models: Merge two models using a linear interpolation.
set_lambda_model: Set the model interpolation constant.
get_output_shape: Get the output shape of a model given the shape of the input.
get_local_model_dict: Get the local model state dictionary.
get_global_model_dict: Get the global model state dictionary.
mix_networks: Mix two networks using a linear interpolation.
batch_norm_to_group_norm: Iterate over a whole model (or a layer of a model) and replace every BatchNorm2d with a GroupNorm.
safe_load_state_dict: Load a state dictionary into a model.
check_model_fit_mem: Check if the models fit in the memory of the device.
Classes¶
class fluke.utils.model.MMMixin
- class fluke.utils.model.MMMixin(*args, **kwargs)[source]¶
Mixin class for model interpolation. This class provides the necessary methods to interpolate between two models. This mixin class must be used as a parent class for the PyTorch modules that need to be interpolated.
Tip
Ideally, when using this mixin to implement a new class M, it should be mixed with a class C that extends torch.nn.Module, and the interpolation of the parameters should happen between the parameters in C and a new set of parameters defined in M. This type of multiple inheritance must have the class MMMixin as first parent and a class that extends torch.nn.Module as second parent. For example:

    # C is a class that extends torch.nn.Module
    class M(MMMixin, C):
        def __init__(self, *args, **kwargs: dict[str, Any]):
            super().__init__(*args, **kwargs)
            self.weight_local = nn.Parameter(torch.zeros_like(self.weight))

In this case, the default implementation of the method get_weight() will work and will interpolate between the weight and the weight_local attributes of the module.
- get_lambda() float [source]¶
Get the interpolation constant.
- Returns:
The interpolation constant.
- Return type:
float
- get_weight() Tensor [source]¶
Get the interpolated weights of the layer or module according to the interpolation constant lam. The default implementation assumes that the layer or module has a weight attribute and a weight_local attribute that are both tensors of the same shape. The interpolated weights are computed as w = (1 - self.lam) * self.weight + self.lam * self.weight_local.
- Returns:
The interpolated weights.
- Return type:
torch.Tensor
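As a quick illustration of this behavior, here is a minimal sketch (the toy subclass, sizes, and lambda value are made up for the example; it assumes set_lambda_model accepts any MMMixin module, as its signature suggests):

    import torch
    import torch.nn as nn
    from fluke.utils.model import MMMixin, set_lambda_model

    # Toy class following the pattern from the Tip above: MMMixin first, then nn.Linear.
    class MixedLinear(MMMixin, nn.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.weight_local = nn.Parameter(torch.zeros_like(self.weight))

    layer = MixedLinear(4, 2)
    set_lambda_model(layer, 0.3)   # sets lam = 0.3
    w = layer.get_weight()         # (1 - 0.3) * weight + 0.3 * weight_local
    print(layer.get_lambda(), w.shape)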
class fluke.utils.model.LinesLinear
- class fluke.utils.model.LinesLinear(*args, **kwargs: dict[str, Any])[source]¶
Bases:
MMMixin
,Linear
Linear layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.
Note
The global weights are the “default” weights of the torch.nn.Linear layer, while the local ones are in the submodule weight_local (and bias_local).
- weight_local¶
The local weights.
- Type:
torch.Tensor
- bias_local¶
The local bias.
- Type:
torch.Tensor
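A minimal usage sketch (the layer sizes and lambda value are illustrative, and it assumes the constructor forwards its arguments to torch.nn.Linear, as the signature suggests):

    import torch
    from fluke.utils.model import LinesLinear, set_lambda_model

    layer = LinesLinear(8, 4)        # global weights are the standard nn.Linear weights
    set_lambda_model(layer, 0.5)     # interpolate halfway between global and local weights
    x = torch.randn(2, 8)
    y = layer(x)                     # the forward pass uses the interpolated weights

The same pattern applies to LinesConv2d, LinesLSTM, LinesEmbedding, and LinesBN2d documented below.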
class fluke.utils.model.LinesConv2d
- class fluke.utils.model.LinesConv2d(*args, **kwargs: dict[str, Any])[source]¶
Bases:
MMMixin
,Conv2d
Conv2d layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.
Note
The global weights are the “default” weights of the torch.nn.Conv2d layer, while the local ones are in the submodule weight_local (and bias_local).
- weight_local¶
The local weights.
- Type:
torch.Tensor
- bias_local¶
The local bias.
- Type:
torch.Tensor
class fluke.utils.model.LinesLSTM
- class fluke.utils.model.LinesLSTM(*args, **kwargs: dict[str, Any])[source]¶
Bases:
MMMixin
,LSTM
LSTM layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.
Note
The global weights are the “default” weights of the torch.nn.LSTM layer, while the local ones are in the submodules weight_hh_l{layer}_local and weight_ih_l{layer}_local, where layer is the layer number. Similar considerations apply to the biases.
Caution
This class may not work properly on all devices. If you encounter any issues, please open an issue in the repository.
- weight_hh_l{layer}_local
The local hidden-hidden weights of layer
layer
.- Type:
torch.Tensor
- weight_ih_l{layer}_local
The local input-hidden weights of layer
layer
.- Type:
torch.Tensor
- bias_hh_l{layer}_local
The local hidden-hidden biases of layer
layer
.- Type:
torch.Tensor
- bias_ih_l{layer}_local
The local input-hidden biases of layer
layer
.- Type:
torch.Tensor
class fluke.utils.model.LinesEmbedding
- class fluke.utils.model.LinesEmbedding(*args, **kwargs: dict[str, Any])[source]¶
Bases:
MMMixin
,Embedding
Embedding layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.
Note
The global weights are the “default” weights of the torch.nn.Embedding layer, while the local ones are in the submodule weight_local.
- weight_local¶
The local weights.
- Type:
torch.Tensor
class fluke.utils.model.LinesBN2d
- class fluke.utils.model.LinesBN2d(*args, **kwargs: dict[str, Any])[source]¶
Bases:
MMMixin
,BatchNorm2d
BatchNorm2d layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.
Note
The global weights are the “default” weights of the torch.nn.BatchNorm2d layer, while the local ones are in the submodules weight_local and bias_local.
- weight_local¶
The local weights.
- Type:
torch.Tensor
- bias_local¶
The local bias.
- Type:
torch.Tensor
class fluke.utils.model.AllLayerOutputModel
- class fluke.utils.model.AllLayerOutputModel(model: Module)[source]¶
Bases:
Module
Wrapper class to get the output of all layers in a model. Once the model is wrapped with this class, the activations of all layers can be accessed through the attributes activations_in and activations_out. activations_in is a dictionary that contains the input activations of all layers, and activations_out is a dictionary that contains the output activations of all layers.
Note
The activations are stored in the order in which they are computed during the forward pass.
Important
If you need to access the activations of a specific layer after a potential activation function, you should use the activations_in of the next layer. For example, if you need the activations of the first layer after the ReLU activation, you should use the activations_in of the second layer. Note that activations_in may not include the activations of the last layer if the last layer includes an activation function.
Important
If your model includes all the activation functions as submodules (e.g., of type torch.nn.ReLU), then you can use the activations_out attribute to get all the activations (i.e., both before and after the activation functions).
- model¶
The model to wrap.
- Type:
torch.nn.Module
- activations_in¶
The input activations of all layers.
- Type:
OrderedDict
- activations_out¶
The output activations of all layers.
- Type:
OrderedDict
- Parameters:
model (torch.nn.Module) – The model to wrap.
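A minimal usage sketch (the wrapped network is a made-up example; it assumes the wrapper is called like the original module and that the activation dictionaries map layer names to tensors):

    import torch
    import torch.nn as nn
    from fluke.utils.model import AllLayerOutputModel

    net = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 2))
    wrapped = AllLayerOutputModel(net)

    _ = wrapped(torch.randn(4, 10))   # the forward pass populates the activation dictionaries
    for name, activation in wrapped.activations_out.items():
        print(name, activation.shape)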
Functions¶
- fluke.utils.model.diff_model(model_dict1: dict, model_dict2: dict) OrderedDict [source]¶
Compute the difference between two model state dictionaries. The difference is computed at the level of the parameters.
- Parameters:
model_dict1 (dict) – The state dictionary of the first model.
model_dict2 (dict) – The state dictionary of the second model.
- Returns:
The state dictionary of the difference between the two models.
- Return type:
OrderedDict
- Raises:
AssertionError – If the two models have different architectures.
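A minimal usage sketch (the models are illustrative; which direction the difference is computed in is not specified here):

    import torch.nn as nn
    from fluke.utils.model import diff_model

    m1, m2 = nn.Linear(5, 3), nn.Linear(5, 3)     # same architecture, different weights
    delta = diff_model(m1.state_dict(), m2.state_dict())
    # 'delta' contains, for each parameter, the difference between the two state dicts.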
- fluke.utils.model.merge_models(model_1: Module, model_2: Module, lam: float) Module [source]¶
Merge two models using a linear interpolation. The interpolation is done at the level of the parameters using the formula:
merged_model = (1 - lam) * model_1 + lam * model_2.
- Parameters:
model_1 (torch.nn.Module) – The first model.
model_2 (torch.nn.Module) – The second model.
lam (float) – The interpolation constant.
- Returns:
The merged model.
- Return type:
Module
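A minimal usage sketch of the formula above (the models and lambda value are illustrative):

    import torch.nn as nn
    from fluke.utils.model import merge_models

    m1, m2 = nn.Linear(5, 3), nn.Linear(5, 3)
    merged = merge_models(m1, m2, lam=0.25)   # parameter-wise: 0.75 * m1 + 0.25 * m2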
- fluke.utils.model.set_lambda_model(model: MMMixin, lam: float, layerwise: bool = False) None [source]¶
Set the model interpolation constant.
Warning
This function performs an in-place operation on the model, and it assumes that the model has been built using the MMMixin classes.
- fluke.utils.model.get_output_shape(model: Module, input_dim: tuple[int, ...]) tuple[int, ...] [source]¶
Get the output shape of a model given the shape of the input.
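A minimal usage sketch (the network is illustrative; whether input_dim includes the batch dimension is not stated here, so the input shape below is an assumption):

    import torch.nn as nn
    from fluke.utils.model import get_output_shape

    cnn = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Flatten())
    out_shape = get_output_shape(cnn, (1, 3, 32, 32))   # output shape for a (1, 3, 32, 32) input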
- fluke.utils.model.get_local_model_dict(model: MMMixin) OrderedDict [source]¶
Get the local model state dictionary.
- Parameters:
model (torch.nn.Module) – the model.
- Returns:
the local model state dictionary.
- Return type:
OrderedDict
- fluke.utils.model.get_global_model_dict(model: MMMixin) OrderedDict [source]¶
Get the global model state dictionary.
- Parameters:
model (torch.nn.Module) – the model.
- Returns:
the global model state dictionary.
- Return type:
OrderedDict
- fluke.utils.model.mix_networks(global_model: Module, local_model: Module, lamda: float) MMMixin [source]¶
Mix two networks using a linear interpolation. This function takes two models and a lambda value and returns a new model that is a linear interpolation of the two input models. It transparently handles the interpolation of the different layers of the models. The returned model implements the MMMixin class and has all its layers swapped with the corresponding interpolated layers.
- Parameters:
global_model (torch.nn.Module) – The global model.
local_model (torch.nn.Module) – The local model.
lamda (float) – The interpolation constant.
- Returns:
The merged/interpolated model that implements the MMMixin class.
- Return type:
Module
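A minimal usage sketch combining mix_networks with set_lambda_model, get_global_model_dict, and get_local_model_dict (the architectures and lambda values are illustrative):

    import torch.nn as nn
    from fluke.utils.model import (mix_networks, set_lambda_model,
                                   get_global_model_dict, get_local_model_dict)

    global_model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))
    local_model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))

    mixed = mix_networks(global_model, local_model, 0.5)   # layers are swapped with interpolated ones
    set_lambda_model(mixed, 0.8)                           # change the interpolation constant afterwards
    global_sd = get_global_model_dict(mixed)               # state dict of the global weights
    local_sd = get_local_model_dict(mixed)                 # state dict of the local weights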
- fluke.utils.model.batch_norm_to_group_norm(layer: Module) Module [source]¶
Iterate over a whole model (or a layer of a model) and replace every BatchNorm2d layer with a GroupNorm layer.
- Parameters:
layer (torch.nn.Module) – model or one layer of a model.
- Returns:
model with group norm layers instead of batch norm layers.
- Return type:
torch.nn.Module
- Raises:
ValueError – If the number of channels \(\notin \{2^i\}_{i=4}^{11}\)
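A minimal usage sketch (the network is illustrative; channel counts are kept to powers of two in the supported range to avoid the ValueError above):

    import torch.nn as nn
    from fluke.utils.model import batch_norm_to_group_norm

    cnn = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.BatchNorm2d(16), nn.ReLU())
    cnn = batch_norm_to_group_norm(cnn)   # every BatchNorm2d is replaced by a GroupNorm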
- fluke.utils.model.safe_load_state_dict(model1: Module, model2_state_dict: dict) None [source]¶
Load a state dictionary into a model. This function is a safe version of model.load_state_dict that handles the case in which the state dictionary has keys that match STATE_DICT_KEYS_TO_IGNORE and thus have to be ignored.
Caution
This function performs an in-place operation on model1.
- Parameters:
model1 (torch.nn.Module) – The model to load the state dictionary.
model2_state_dict (dict) – The state dictionary.
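A minimal usage sketch (the models are illustrative):

    import torch.nn as nn
    from fluke.utils.model import safe_load_state_dict

    source = nn.Linear(5, 3)
    target = nn.Linear(5, 3)
    safe_load_state_dict(target, source.state_dict())   # loads in place into 'target'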
- fluke.utils.model.check_model_fit_mem(model: Module, input_size: tuple[int, ...], num_clients: int, device: str = 'cuda', mps_default: bool = True)[source]¶
Check if the models fit in the memory of the device. The function estimates the memory usage on the device when each client and the server owns a single neural network, and checks whether the models fit in the memory of the device.
Attention
This function only works for CUDA devices. For MPS devices, the function will always return the value of mps_default. To date, PyTorch does not provide a way to estimate the memory usage of a model on an MPS device.
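A minimal usage sketch (the model, input size, and number of clients are illustrative; whether input_size includes the batch dimension is not stated here):

    import torch.nn as nn
    from fluke.utils.model import check_model_fit_mem

    model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
    fits = check_model_fit_mem(model, input_size=(784,), num_clients=100, device="cuda")
    # 'fits' indicates whether the estimated memory usage fits on the CUDA device.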