fluke.utils.model

This submodule provides utilities for pytorch model manipulation.

Classes included in fluke.utils.model

MMMixin

Mixin class for model interpolation.

LinesLinear

Linear layer with global and local weights.

LinesConv2d

Conv2d layer with global and local weights.

LinesLSTM

LSTM layer with global and local weights.

LinesEmbedding

Embedding layer with global and local weights.

LinesBN2d

BatchNorm2d layer with global and local weights.

AllLayerOutputModel

Wrapper class to get the output of all layers in a model.

Functions included in fluke.utils.model

diff_model

Compute the difference between two model state dictionaries.

merge_models

Merge two models using a linear interpolation.

set_lambda_model

Set model interpolation constant.

get_output_shape

Get the output shape of a model given the shape of the input.

get_local_model_dict

Get the local model state dictionary.

get_global_model_dict

Get the global model state dictionary.

mix_networks

Mix two networks using a linear interpolation.

batch_norm_to_group_norm

Iterate over a whole model (or a layer of a model) and replace every 2D batch norm layer with a group norm layer.

safe_load_state_dict

Load a state dictionary into a model.

check_model_fit_mem

Check if the models fit in the memory of the device.

Classes

class fluke.utils.model.MMMixin

class fluke.utils.model.MMMixin(*args, **kwargs)[source]

Mixin class for model interpolation. This class provides the necessary methods to interpolate between two models. This mixin class must be used as a parent class for the PyTorch modules that need to be interpolated.

Tip

Ideally, when using this mixin to implement a new class M, mix it with a class C that extends torch.nn.Module, so that the interpolation happens between the parameters in C and a new set of parameters defined in M. In this multiple inheritance, MMMixin must be the first parent and the class that extends torch.nn.Module must be the second. For example:

from typing import Any
import torch
from torch import nn
from fluke.utils.model import MMMixin

# C is a class that extends torch.nn.Module
class M(MMMixin, C):
    def __init__(self, *args, **kwargs: dict[str, Any]):
        super().__init__(*args, **kwargs)
        self.weight_local = nn.Parameter(torch.zeros_like(self.weight))

In this case, the default implementation of the method get_weight() will work and will interpolate between the weight and weight_local attributes of the module.

lam

The interpolation constant.

Type:

float

get_lambda() float[source]

Get the interpolation constant.

Returns:

The interpolation constant.

Return type:

float

get_weight() Tensor[source]

Get the interpolated weights of the layer or module according to the interpolation constant lam. The default implementation assumes that the layer or module has a weight attribute and a weight_local attribute that are both tensors of the same shape. The interpolated weights are computed as: w = (1 - self.lam) * self.weight + self.lam * self.weight_local.

Returns:

The interpolated weights.

Return type:

torch.Tensor
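The formula above can be checked with a minimal, framework-free sketch. It uses plain Python lists instead of tensors, and the helper name interpolate is hypothetical (it is not part of fluke); with PyTorch tensors the same expression applies element-wise.

```python
def interpolate(weight, weight_local, lam):
    """Return (1 - lam) * weight + lam * weight_local, element by element."""
    return [(1 - lam) * g + lam * l for g, l in zip(weight, weight_local)]


global_w = [1.0, 2.0, 3.0]  # stands in for self.weight
local_w = [3.0, 2.0, 1.0]   # stands in for self.weight_local

only_global = interpolate(global_w, local_w, 0.0)  # lam = 0 keeps the global weights
only_local = interpolate(global_w, local_w, 1.0)   # lam = 1 keeps the local weights
halfway = interpolate(global_w, local_w, 0.5)      # equal mix of the two
```

With lam = 0 the result equals the global weights, and with lam = 1 it equals the local weights.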

set_lambda(lam) None[source]

Set the interpolation constant.

Parameters:

lam (float) – The interpolation constant.

class fluke.utils.model.LinesLinear

class fluke.utils.model.LinesLinear(*args, **kwargs: dict[str, Any])[source]

Bases: MMMixin, Linear

Linear layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.

Note

The global weights are the “default” weights of the torch.nn.Linear layer, while the local ones are in the submodule weight_local (and bias_local).

weight_local

The local weights.

Type:

torch.Tensor

bias_local

The local bias.

Type:

torch.Tensor

class fluke.utils.model.LinesConv2d

class fluke.utils.model.LinesConv2d(*args, **kwargs: dict[str, Any])[source]

Bases: MMMixin, Conv2d

Conv2d layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.

Note

The global weights are the “default” weights of the torch.nn.Conv2d layer, while the local ones are in the submodule weight_local (and bias_local).

weight_local

The local weights.

Type:

torch.Tensor

bias_local

The local bias.

Type:

torch.Tensor

class fluke.utils.model.LinesLSTM

class fluke.utils.model.LinesLSTM(*args, **kwargs: dict[str, Any])[source]

Bases: MMMixin, LSTM

LSTM layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.

Note

The global weights are the “default” weights of the torch.nn.LSTM layer, while the local ones are in the submodules weight_hh_l{layer}_local and weight_ih_l{layer}_local, where layer is the layer number. Similar considerations apply to the biases.

Caution

This class may not work properly on all devices. If you encounter any issues, please open an issue in the repository.

weight_hh_l{layer}_local

The local hidden-hidden weights of layer layer.

Type:

torch.Tensor

weight_ih_l{layer}_local

The local input-hidden weights of layer layer.

Type:

torch.Tensor

bias_hh_l{layer}_local

The local hidden-hidden biases of layer layer.

Type:

torch.Tensor

bias_ih_l{layer}_local

The local input-hidden biases of layer layer.

Type:

torch.Tensor

class fluke.utils.model.LinesEmbedding

class fluke.utils.model.LinesEmbedding(*args, **kwargs: dict[str, Any])[source]

Bases: MMMixin, Embedding

Embedding layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.

Note

The global weights are the “default” weights of the torch.nn.Embedding layer, while the local ones are in the submodule weight_local.

weight_local

The local weights.

Type:

torch.Tensor

class fluke.utils.model.LinesBN2d

class fluke.utils.model.LinesBN2d(*args, **kwargs: dict[str, Any])[source]

Bases: MMMixin, BatchNorm2d

BatchNorm2d layer with global and local weights. The weights are interpolated using the interpolation constant lam. Thus, the forward pass of this layer will use the interpolated weights.

Note

The global weights are the “default” weights of the nn.BatchNorm2d layer, while the local ones are in the submodules weight_local and bias_local.

weight_local

The local weights.

Type:

torch.Tensor

bias_local

The local bias.

Type:

torch.Tensor

class fluke.utils.model.AllLayerOutputModel

class fluke.utils.model.AllLayerOutputModel(model: Module)[source]

Bases: Module

Wrapper class to get the output of all layers in a model. Once the model is wrapped with this class, the activations of all layers can be accessed through the attributes activations_in and activations_out.

activations_in is a dictionary that contains the input activations of all layers. activations_out is a dictionary that contains the output activations of all layers.

Note

The activations are stored in the order in which they are computed during the forward pass.

Important

If you need to access the activations of a specific layer after a potential activation function, you should use the activations_in of the next layer. For example, if you need the activations of the first layer after the ReLU activation, you should use the activations_in of the second layer. This attribute may not include the activations of the last layer if it includes an activation function.

Important

If your model includes all of its activation functions as submodules (e.g., of type torch.nn.ReLU), then you can use the activations_out attribute to get all the activations (i.e., before and after the activation functions).
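The relationship between activations_in and activations_out described in the notes above can be illustrated with a toy, framework-free analogy. This is not how AllLayerOutputModel is implemented: the class RecordingChain, the layer names fc1 and fc2, and the plain-list "tensors" are all hypothetical and only mimic a model whose activation function is applied outside the layer, so it is never recorded as a layer output.

```python
from collections import OrderedDict


class RecordingChain:
    """Toy stand-in for a wrapped model: runs named layers in order and
    records what goes into and what comes out of each layer."""

    def __init__(self, layers):
        self.layers = layers  # list of (name, layer, functional_activation)
        self.activations_in = OrderedDict()
        self.activations_out = OrderedDict()

    def __call__(self, x):
        for name, layer, activation in self.layers:
            self.activations_in[name] = x
            x = layer(x)
            self.activations_out[name] = x  # value before the activation
            x = activation(x)               # applied outside the layer: not recorded
        return x


def relu(values):
    return [max(0.0, v) for v in values]


def identity(values):
    return values


chain = RecordingChain([
    ("fc1", lambda v: [e - 2.0 for e in v], relu),
    ("fc2", lambda v: [2.0 * e for e in v], identity),
])
output = chain([1.0, 3.0])

pre_relu = chain.activations_out["fc1"]   # output of fc1 before ReLU
post_relu = chain.activations_in["fc2"]   # output of fc1 after ReLU
```

In this sketch the post-ReLU values of fc1 are only visible as the input of fc2, which mirrors the advice to read activations_in of the next layer.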

model

The model to wrap.

Type:

torch.nn.Module

activations_in

The input activations of all layers.

Type:

OrderedDict

activations_out

The output activations of all layers.

Type:

OrderedDict

Parameters:

model (torch.nn.Module) – The model to wrap.

activate() None[source]

Activate the recording of the activations of all layers.

deactivate(clear_activations: bool = True) None[source]

Deactivate the recording of the activations of all layers.

is_active() bool[source]

Returns whether the all layer output model is active.

Returns:

Whether the all layer output model is active.

Return type:

bool

Functions

fluke.utils.model.diff_model(model_dict1: dict, model_dict2: dict) OrderedDict[source]

Compute the difference between two model state dictionaries. The difference is computed at the level of the parameters.

Parameters:
  • model_dict1 (dict) – The state dictionary of the first model.

  • model_dict2 (dict) – The state dictionary of the second model.

Returns:

The state dictionary of the difference between the two models.

Return type:

OrderedDict

Raises:

AssertionError – If the two models have different architectures.
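A parameter-level difference can be sketched as follows. This is a framework-free illustration with lists in place of tensors; the helper name diff_state_dicts is hypothetical, and the direction of the subtraction (first minus second) is an assumption, since the text above does not state it.

```python
from collections import OrderedDict


def diff_state_dicts(dict1, dict2):
    """Return dict1 - dict2 key by key (direction is an assumption)."""
    # Different key sets are treated here as "different architectures".
    assert dict1.keys() == dict2.keys(), "models have different architectures"
    return OrderedDict(
        (key, [a - b for a, b in zip(dict1[key], dict2[key])]) for key in dict1
    )


model_a = {"fc.weight": [1.0, 2.0], "fc.bias": [0.5]}
model_b = {"fc.weight": [0.5, 0.5], "fc.bias": [0.25]}
delta = diff_state_dicts(model_a, model_b)
```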

fluke.utils.model.merge_models(model_1: Module, model_2: Module, lam: float) Module[source]

Merge two models using a linear interpolation. The interpolation is done at the level of the parameters using the formula: merged_model = (1 - lam) * model_1 + lam * model_2.

Parameters:
  • model_1 (torch.nn.Module) – The first model.

  • model_2 (torch.nn.Module) – The second model.

  • lam (float) – The interpolation constant.

Returns:

The merged model.

Return type:

Module

fluke.utils.model.set_lambda_model(model: MMMixin, lam: float, layerwise: bool = False) None[source]

Set model interpolation constant.

Warning

This function performs an inplace operation on the model, and it assumes that the model has been built using the MMMixin classes.

Parameters:
  • model (torch.nn.Module) – The model whose interpolation constant is set.

  • lam (float) – The constant used for interpolation (0 retrieves the global model, 1 retrieves the local model).

  • layerwise (bool) – Whether to set a different lambda for each layer. Defaults to False.

fluke.utils.model.get_output_shape(model: Module, input_dim: tuple[int, ...]) tuple[int, ...][source]

Get the output shape of a model given the shape of the input.

Parameters:
  • model (torch.nn.Module) – The model to get the output shape.

  • input_dim (tuple[int, ...]) – The expected input shape of the model.

Returns:

The output shape of the model.

Return type:

tuple[int, ...]

fluke.utils.model.get_local_model_dict(model: MMMixin) OrderedDict[source]

Get the local model state dictionary.

Parameters:

model (torch.nn.Module) – the model.

Returns:

the local model state dictionary.

Return type:

OrderedDict

fluke.utils.model.get_global_model_dict(model: MMMixin) OrderedDict[source]

Get the global model state dictionary.

Parameters:

model (torch.nn.Module) – the model.

Returns:

the global model state dictionary.

Return type:

OrderedDict
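One plausible way to separate the two sets of parameters is by key name, since the local parameters documented above all carry a _local suffix (weight_local, bias_local, weight_hh_l{layer}_local, and so on). The sketch below is an assumption, not fluke's code: the helper name split_state_dict is hypothetical, and stripping the suffix from the local keys is a design choice made here so that both dictionaries share the same key names.

```python
from collections import OrderedDict

LOCAL_SUFFIX = "_local"  # assumption based on the documented attribute names


def split_state_dict(state_dict):
    """Split a mixed state dict into (global_dict, local_dict) by key suffix."""
    global_dict, local_dict = OrderedDict(), OrderedDict()
    for key, value in state_dict.items():
        if key.endswith(LOCAL_SUFFIX):
            # Drop the suffix so local keys line up with the global ones.
            local_dict[key[: -len(LOCAL_SUFFIX)]] = value
        else:
            global_dict[key] = value
    return global_dict, local_dict


mixed = OrderedDict([
    ("fc.weight", [1.0, 2.0]),
    ("fc.weight_local", [3.0, 4.0]),
    ("fc.bias", [0.1]),
    ("fc.bias_local", [0.2]),
])
global_part, local_part = split_state_dict(mixed)
```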

fluke.utils.model.mix_networks(global_model: Module, local_model: Module, lamda: float) MMMixin[source]

Mix two networks using a linear interpolation. This method takes two models and a lambda value and returns a new model that is a linear interpolation of the two input models. It transparently handles the interpolation of the different layers of the models. The returned model implements the MMMixin class and has all the layers swapped with the corresponding interpolated layers.

Parameters:
  • global_model (torch.nn.Module) – The global model.

  • local_model (torch.nn.Module) – The local model.

  • lamda (float) – The interpolation constant.

Returns:

The merged/interpolated model that implements the MMMixin class.

Return type:

Module

fluke.utils.model.batch_norm_to_group_norm(layer: Module) Module[source]

Iterate over a whole model (or a layer of a model) and replace every 2D batch norm layer with a group norm layer.

Parameters:

layer (torch.nn.Module) – model or one layer of a model.

Returns:

model with group norm layers instead of batch norm layers.

Return type:

torch.nn.Module

Raises:

ValueError – If the number of channels is not a power of two between \(2^4\) and \(2^{11}\), i.e., if it is \(\notin \{2^i\}_{i=4}^{11}\).
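The constraint on the number of channels can be made concrete with a small sketch. The helper name check_channels is hypothetical and only reproduces the membership test stated above. It does not show how the number of groups is chosen, which this page does not document.

```python
# The supported channel counts: 16, 32, 64, 128, 256, 512, 1024, 2048.
SUPPORTED_CHANNELS = frozenset(2 ** i for i in range(4, 12))


def check_channels(num_channels):
    """Return num_channels if supported, otherwise raise ValueError."""
    if num_channels not in SUPPORTED_CHANNELS:
        raise ValueError(f"unsupported number of channels: {num_channels}")
    return num_channels


smallest = min(SUPPORTED_CHANNELS)
largest = max(SUPPORTED_CHANNELS)
```

Under this reading, a batch norm layer with, for example, 8 or 48 channels would trigger the error.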

fluke.utils.model.safe_load_state_dict(model1: Module, model2_state_dict: dict) None[source]

Load a state dictionary into a model. This function is a safe version of model.load_state_dict that handles the case in which the state dictionary has keys that match with STATE_DICT_KEYS_TO_IGNORE and thus have to be ignored.

Caution

This function performs an inplace operation on model1.

Parameters:
  • model1 (torch.nn.Module) – The model to load the state dictionary.

  • model2_state_dict (dict) – The state dictionary.
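The key filtering described above might look like the sketch below. Everything in it is an assumption: the contents of STATE_DICT_KEYS_TO_IGNORE are not given on this page, so the tuple KEYS_TO_IGNORE is a hypothetical stand-in (num_batches_tracked is a buffer name used by PyTorch batch norm layers), the helper name filter_state_dict is made up, and matching by key suffix is only one possible interpretation of "match".

```python
KEYS_TO_IGNORE = ("num_batches_tracked",)  # hypothetical stand-in


def filter_state_dict(state_dict, keys_to_ignore=KEYS_TO_IGNORE):
    """Drop every entry whose key ends with one of the ignored names."""
    return {
        key: value
        for key, value in state_dict.items()
        if not key.endswith(keys_to_ignore)  # str.endswith accepts a tuple
    }


incoming = {
    "bn.weight": [1.0],
    "bn.bias": [0.0],
    "bn.num_batches_tracked": 12,
}
filtered = filter_state_dict(incoming)
```

The filtered dictionary could then be loaded into the target model, while the ignored entries of the target keep their current values.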

fluke.utils.model.check_model_fit_mem(model: Module, input_size: tuple[int, ...], num_clients: int, device: str = 'cuda', mps_default: bool = True)[source]

Check if the models fit in the memory of the device. The function estimates the memory usage on the device when every client and the server each hold one copy of the model, and checks whether that total fits in the device memory.

Attention

This function only works for CUDA devices. For MPS devices, the function will always return the value of mps_default. To date, PyTorch does not provide a way to estimate the memory usage of a model on an MPS device.

Parameters:
  • model (torch.nn.Module) – The model to check.

  • input_size (tuple[int, ...]) – The input size of the model.

  • num_clients (int) – The number of clients in the federation.

  • device (str, optional) – The device to check. Defaults to 'cuda'.

  • mps_default (bool, optional) – The default value to return if the device is MPS.
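As a rough back-of-the-envelope illustration of the "one model per client plus one for the server" assumption, the sketch below counts parameter storage only. It is not fluke's estimator: the function names and the default of 4 bytes per parameter (float32) are assumptions, and it ignores activations, gradients, optimizer state and the input_size argument that the real function takes.

```python
def estimate_total_bytes(num_parameters, num_clients, bytes_per_parameter=4):
    """Parameter memory for num_clients client models plus one server model."""
    return num_parameters * bytes_per_parameter * (num_clients + 1)


def fits_in_memory(num_parameters, num_clients, free_bytes, bytes_per_parameter=4):
    """Compare the parameter-only estimate against the available memory."""
    total = estimate_total_bytes(num_parameters, num_clients, bytes_per_parameter)
    return total <= free_bytes


eight_gib = 8 * 1024 ** 3
total_100 = estimate_total_bytes(1_000_000, 100)          # 101 copies of the model
ok_100 = fits_in_memory(1_000_000, 100, eight_gib)        # small federation
ok_10000 = fits_in_memory(1_000_000, 10_000, eight_gib)   # large federation
```

Because it omits everything except the parameters, this estimate is a lower bound on the memory actually needed.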