What is Federated Learning?

In classical Machine Learning, we have a model and data. The model might be a neural network or a simple linear regression. We train it on the available data to perform a useful task, such as detecting objects, transcribing an audio file into text, or playing games such as Go or chess.

However, the data we want to train on is not always owned by us or openly available. Because of data privacy, users’ data is often never shared, so it remains untapped, benefiting neither the users nor the companies. Mobile phones are one of the richest sources of such data: location data, typing data from keyboards, and speed and position information from vehicles are all examples.

Normally, this data would be collected at a central point, and training would proceed as described in the first paragraph. Federated Learning takes the opposite approach: training happens on the local devices, each using its own data, and only the resulting model updates are sent back to improve the global model. In this way, data privacy is preserved, and models can still be built from otherwise inaccessible data.
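
To make this concrete, the core of most federated schemes is a simple loop: the server broadcasts the current global weights, each client trains locally, and the server averages the returned weights, weighted by each client’s dataset size. The sketch below is illustrative only; `local_train` is a hypothetical client method, not a library call.

```python
def federated_round(global_weights, clients):
    """One FedAvg-style round: broadcast, local training, weighted averaging.

    `global_weights` is a list of NumPy ndarrays; each client exposes a
    hypothetical `local_train(weights) -> (new_weights, num_examples)`.
    """
    results = [client.local_train(global_weights) for client in clients]
    total = sum(n for _, n in results)
    # Weighted average of each weight tensor across all clients
    return [
        sum(w[i] * (n / total) for w, n in results)
        for i in range(len(global_weights))
    ]
```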

For Federated Learning, we use the Flower framework, which is compatible with most Deep Learning and Machine Learning frameworks and is fast and efficient. Since Flower, currently one of the leading federated learning frameworks, also supports the HuggingFace Transformers framework, the integration is straightforward.

Flower Client


We begin with the implementation of the Flower client. Inside the client class, we first set up the Train_Federated training class prepared for Federated Learning, as discussed in sections 1.1 HF Transformers SAM’s Fine Tune and 1.4 Fine Tune of SAM.

To federate our example across multiple clients, we first need to write our Flower client class (inheriting from flwr.client.NumPyClient). This is straightforward, as our model is a standard PyTorch model.

```python
class SAMClient(fl.client.NumPyClient):
    """Flower client implementing federated fine-tuning of SAM.

    Args:
        dataset_root (str): Root directory of the dataset
        image_subfolder (str): Name of the image subfolder
        annotation_subfolder (str): Name of the annotation subfolder
        batch_size (int): Batch size for training
        num_epochs (int): Number of epochs for training
    """
```

Since the methods used by SAMClient delegate to the training class, we only give a general overview of each method here.

get_parameters

```python
    def get_parameters(self, **kwargs):
        # returns the current local model parameters (before training)
        return self.train_model.get_model_parameters(self.train_model.model)
```
  • The get_parameters method returns the client’s current model parameters, using the get_model_parameters helper from the training class:
```python
    def get_model_parameters(self, model):
        """Get model parameters as a list of NumPy ndarrays.

        Args:
            model (nn.Module): Model to get the parameters from.

        Returns:
            list: List of NumPy ndarrays representing the parameters.
        """
        return [val.cpu().numpy() for _, val in model.state_dict().items()]
```

```python
    def set_parameters(self, parameters):
        # set model parameters received from the server
        self.train_model.set_model_parameters(self.train_model.model, parameters)
```
  • Conversely, the set_parameters method applies the parameters received from the server to the local model, using the set_model_parameters helper from the training class in a similar manner:


```python
    def set_model_parameters(self, model, parameters):
        """Set model parameters.

        Args:
            model (nn.Module): Model to set the parameters for.
            parameters (list): List of NumPy ndarrays representing the parameters.
        """
        # requires `from collections import OrderedDict` and `import torch`
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict)
```
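
As a quick sanity check, the two helpers are inverses of each other: parameters extracted with get_model_parameters can be loaded back with set_model_parameters without changing the model. A minimal sketch, assuming `train_model` is an instance of the training class that defines both helpers, and using a stand-in PyTorch model:

```python
import torch.nn as nn

# Stand-in model for illustration; in the real pipeline this is the SAM model.
model = nn.Linear(4, 2)

params = train_model.get_model_parameters(model)   # list of NumPy ndarrays
train_model.set_model_parameters(model, params)    # loads them back

# the round trip leaves every weight unchanged
for before, after in zip(params, train_model.get_model_parameters(model)):
    assert (before == after).all()
```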

Flower Server

In this section, we will create a server for the client we have built. First, we need to choose the strategy to employ. Flower ships with popular ready-to-use federated learning strategies. Strategies define the rules that govern how local updates are combined and managed during the iterative model training process. They may also include mechanisms to handle communication constraints, mitigate the effects of outlier participants, and improve model convergence rates.

Research continues to extend and enhance existing strategies to address new challenges and scenarios in federated learning, including supported learning and hybrid federated dual coordinate ascent. Flower also allows custom federated learning strategies to be implemented. Among the general-purpose server-side strategies, FedAvg is the most widely used, and when combined with suitable client-side techniques for handling non-IID data in image tasks (for example, MAML), it can be highly effective. FedAvg will therefore be our server-side strategy. The snippets below are only examples, but federated training can be started or extended in this manner.

```python
# strategy selection
strategy = fl.server.strategy.FedAvg(
    min_fit_clients=2,
    min_available_clients=2,
)
```

Here, as mentioned in the previous paragraph, we choose FedAvg as our strategy by instantiating the corresponding class from the Flower framework. The min_fit_clients and min_available_clients arguments specify that each round requires at least 2 clients.

```python
# server configuration
server_config = fl.server.ServerConfig(
    num_rounds=10
)
```

Here, we define the number of federated rounds the server will run.

Below, we initialize our server:

```python
# server initialization
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=server_config,
    strategy=strategy,
)
```
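
For completeness, each participating client process connects to this server through Flower’s client launcher. A sketch, with placeholder constructor values; note that in newer Flower releases start_numpy_client is deprecated in favor of fl.client.start_client:

```python
# client initialization (run one such process per participating client)
fl.client.start_numpy_client(
    server_address="0.0.0.0:8080",
    client=SAMClient(
        dataset_root="data/",              # placeholder paths and settings
        image_subfolder="images",
        annotation_subfolder="annotations",
        batch_size=2,
        num_epochs=1,
    ),
)
```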

Flower Segment Anything Model Training

In the client section, we described our local training in section 1.4 Fine Tune of SAM. The client uses certain functions that are not present in local training, as follows:

```python
    def set_model_parameters(self, model, parameters):
        """Set model parameters.

        Args:
            model (nn.Module): Model to set the parameters for.
            parameters (list): List of NumPy ndarrays representing the parameters.
        """
        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict)
```
  • set_model_parameters loads the aggregated model parameters sent back by the server into the local model, after they were collected from the clients with get_model_parameters.
```python
    def get_model_parameters(self, model):
        """Get model parameters as a list of NumPy ndarrays.

        Args:
            model (nn.Module): Model to get the parameters from.

        Returns:
            list: List of NumPy ndarrays representing the parameters.
        """
        return [val.cpu().numpy() for _, val in model.state_dict().items()]
```
  • get_model_parameters collects the model parameters from the clients for transmission to the server. The server aggregates them with the strategy described in the Flower Server section and returns them to the clients via set_model_parameters.

  • Returning to the client side, the model is trained as follows:

```python
    def fit(self, parameters, config):
        # trains the model starting from the parameters received from the server
        updated_parameters = self.train_model.train(initial_parameters=parameters)
        return updated_parameters, len(self.train_model.train_dataloader().dataset), {}
```

The fit method receives the global parameters, trains the model locally, and returns the updated parameters along with the size of the local training dataset.
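
Alongside fit, a NumPyClient can also expose an evaluate method so the server can collect loss and metrics after each round. A minimal sketch, assuming the training class provides a hypothetical `evaluate` helper and `val_dataloader`:

```python
    def evaluate(self, parameters, config):
        # evaluates the received global parameters on the local validation data
        self.train_model.set_model_parameters(self.train_model.model, parameters)
        loss = self.train_model.evaluate()  # hypothetical helper on the training class
        num_examples = len(self.train_model.val_dataloader().dataset)
        return float(loss), num_examples, {}
```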
